YOLO v2目标检测详解三 去除无效数据

在从文件读入标注的数据时,会把物体数量向物体最多的那张图补齐,补齐的时候会添加进不少无效的框,最后计算的时候需要将这部分无效数据去除,添加的无效数据为(0,0,0,0),现在需要将这部分数据去掉

#把添加的无效数据去除
def gt_mask_from_gts(gts):
    gt_stk = gts.view(-1, 4)
    invalid_gt = torch.Tensor([0, 0, 0, 0])
    if CAN_USE_GPU:
        invalid_gt = invalid_gt.cuda()
    gt_mask = torch.zeros(size=(gt_stk.shape[0], ))
    gt_mask[gt_stk.eq(invalid_gt.view(1, 4)).sum(1) != 4] = 1
    return gt_mask.view(gts.shape[0], gts.shape[1])

部分无效的iou也需要同样去掉

YOLO v2目标检测详解三 去除无效数据_第1张图片

如图所示,红色框的中心是在第二个黑色框中,那么第一个黑色的框对应的anchor是不需要和红色的框计算iou的,那么只需要保留第二个anchor和红色框的iou,那么可以用对应的tensor来表示,如:0,0,0,0,0,1,1,1,1,1就表示取第二个框对应的anchor和红色框计算iou

def range_mask_from_gts(gts, w_n, anchor_num, cell_anchor_num=5):
    batch_size, gt_num = gts.shape[0], gts.shape[1]
    gt_stk = gts.view(-1, 4)
    xmin, ymin, xmax, ymax = gt_stk[:, 0], gt_stk[:, 1], gt_stk[:, 2], gt_stk[:, 3]
    x, y = ((xmin+xmax+1)/64).int(), ((ymin+ymax+1)/64).int()
    start = (y*w_n+x)*cell_anchor_num
    range_mask = torch.zeros(batch_size*gt_num, anchor_num)
    for i in range(start.shape[0]):
        range_mask[i, start[i].item():start[i].item()+cell_anchor_num] = 1
    return range_mask.view(batch_size, gt_num, anchor_num)

 

你可能感兴趣的:(pytorch,机器学习)