pytorch-NMS快速上手

NMS 非极大值抑制

理论知识请自行百度 此处只用pytorch快速完成代码

虚拟数据

import numpy as np
import torch
boxes=np.array([[100,100,210,210,0.72],
        [250,250,420,420,0.8],
        [220,220,320,330,0.92],
        [100,100,210,210,0.72],
        [230,240,325,330,0.81],
        [220,230,315,340,0.9]]) 
indices = boxes[:,4].sort(descending=True) #置信度 依靠他排序

计算体


class Calculator(object):
    def __init__(self) -> None:
        super().__init__()
	#数据处理
    def getPruneRes(self,first,boxes,indices,thre):
        self.first = first
        self.indices = indices
        boxes = torch.FloatTensor(boxes)
        self.x = boxes[:,0]
        self.y = boxes[:,1]
        self.xx = boxes[:,2]
        self.yy = boxes[:,3]
        self.scores = boxes[:,4]
        self.areas = (self.xx-self.x)*(self.yy-self.y)
        self.thre = thre
        return torch.tensor(list(map(self.cal,self.indices)),requires_grad=False)
	#核心计算 clamp是设定阈值,就是锁定相交部分正方形的四个点信息
    def cal(self,indice):
        x = torch.clamp(self.x[indice],min=self.x[self.first])
        y = torch.clamp(self.y[indice],min=self.y[self.first])
        xx = torch.clamp(self.xx[indice],max=self.xx[self.first])
        yy = torch.clamp(self.yy[indice],max=self.yy[self.first])
        #IOU
        area = (xx-x).clamp(min=0)*(yy-y).clamp(min=0)
        #小于阈值就返回False
        if(area/(self.areas[indice]+self.areas[self.first]-area)) < self.thre:
            return True
        return False

最后调用getPruneRes会获得布尔值的数组,用它去把indices过滤掉就行了
然后跑循环即可


calculator = Calculator()
res=[]
while(indices.numel()>0):
    first = indices[0]
    res.append(first.item())
    indices = indices[1:]
    if(indices.numel()==0):
        break
    mask = calculator.getPruneRes(first,boxes,indices,thre=0.5)
    indices = indices.masked_select(mask)
print(res)
#[2, 5, 4]

你可能感兴趣的:(深度学习,pytorch,python,pytorch,深度学习,神经网络,图像识别)