NMS计算代码

'''
# INPUT:所有预测出的bounding box (bbx)信息(坐标和置信度confidence), IOU阈值(大于该阈值的bbx将被移除)
for object in all objects:
   (1) 获取当前目标类别下所有bbx的信息
   (2) 将bbx按照confidence从高到低排序,并记录当前confidence最大的bbx
   (3) 计算最大confidence对应的bbx与剩下所有的bbx的IOU,移除所有大于IOU阈值的bbx
   (4) 对剩下的bbx,循环执行(2)和(3)直到所有的bbx均满足要求(即不能再移除bbx)
'''

import numpy as np

def non_max_suppress(predicts_dict, threshhold=0.2):
    '''
    :param predicts_dict: {'分类1':[[Xmin, Ymin, Xmax, Ymax, Score], [...]], '分类2':[[...]]}
    :param threshhold: suprress threshhold
    :return:
    '''

    for object_name, bbox in predicts_dict.items():
        # list to array
        bbox_array = np.array(bbox, dtype=np.float)

        # get coordinates
        Xmin, Ymin, Xmax, Ymax, scores = bbox_array[:, 0], bbox_array[:, 1], bbox_array[:, 2], bbox_array[:, 3], bbox_array[:, 4],

        # 获得每个bbox的面积,用于计算并集
        bbox_area = (Xmax - Xmin) * (Ymax - Ymin)

        # 按score降序排序的bbox索引
        order = scores.argsort()[::-1]

        # 最终保留的bbox索引
        keep = []

        while order.size > 0:

            # 第一个为该类别最大置信度的索引,保留
            i = order[0]
            keep.append(i)

            # 计算与其他bbox的IOU
            # 计算左上角和右下角坐标
            inter_Xmin = np.maximum(Xmin[i], Xmin[order[1:]])
            inter_Ymin = np.maximum(Ymin[i], Ymin[order[1:]])
            inter_Xmax = np.minimum(Xmax[i], Xmax[order[1:]])
            inter_Ymax = np.minimum(Ymax[i], Ymax[order[1:]])

            # 计算交集
            inter_area = np.maximum(0, inter_Xmax-inter_Xmin) * np.maximum(0, inter_Ymax - inter_Ymin)

            # 计算IOU
            IOU = inter_area / (bbox_area[i] + bbox_area[order[1:]] - inter_area + 1e-6)

            # 获取保留下来的索引(因为没有计算与自身的IOU,所以索引相差1,需要加上)
            indexs = np.where(IOU <= threshhold)[0] + 1
            order = order[indexs]

        bbox = bbox_array[keep]
        predicts_dict[object_name] = bbox.tolist()

    return predicts_dict


predicts_dict = non_max_suppress({'苹果':[[0, 0, 1, 1, 0.6], [2, 2, 4, 4, 0.8], [3, 3, 4, 4, 0.5]]})
print(predicts_dict)

 

你可能感兴趣的:(目标检测)