NMS-常规NMS和DIOU NMS实际效果

本人在github上的YOLOv3(pytorch版本)的后处理过程中,使用DIOU NMS代替常规的NMS过程(即IOU的计算替换为DIOU),实际差别如下(数据集仅有一个类别,内河航道船舶):

NMS:
	Average Precisions:
	+ Class '0' (boat) - AP: 0.9324305962764635
	mAP: 0.9324305962764635

DIOU NMS:
	Average Precisions:
	+ Class '0' (boat) - AP: 0.9284645071187682
	mAP: 0.9284645071187682

实际来看DIOU NMS效果不如常规NMS。因此在实际使用过程中,需要具体问题具体分析。
ps:感觉DIOU NMS的实现应该就是把IOU换为DIOU吧…不太确定,上面的结果是按照这种思路来的。


IOU和DIOU代码:

def bbox_iou(box1, box2, x1y1x2y2=True, diou=False):
    """
    Returns the IoU of two bounding boxes
    """
    if not x1y1x2y2:
        # Transform from center and width to exact coordinates
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    # get the corrdinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)
    # Intersection area
    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(
        inter_rect_y2 - inter_rect_y1 + 1, min=0
    )
    # Union Area
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
    if diou:
        union = b1_area + b2_area - inter_area

        # compute enclose area
        enclose_l = torch.min(b1_x1, b2_x1)
        enclose_t = torch.min(b1_y1, b2_y1)
        enclose_r = torch.max(b1_x2, b2_x2)
        enclose_b = torch.max(b1_y2, b2_y2)
        enclose_w = enclose_r - enclose_l
        enclose_h = enclose_b - enclose_t
        enclose_diag = torch.pow(enclose_w, 2) + torch.pow(enclose_h, 2)

        # compute center diag
        center_b1_cx = (b1_x1 + b1_x2) / 2
        center_b1_cy = (b1_y1 + b1_y2) / 2
        center_b2_cx = (b2_x1 + b2_x2) / 2
        center_b2_cy = (b2_y1 + b2_y2) / 2
        center_diag = torch.pow(center_b1_cx - center_b2_cx, 2) + torch.pow(center_b1_cy - center_b2_cy, 2)

        diou = torch.clamp(inter_area / union - center_diag / enclose_diag, min=-1., max=1.)
        return diou

    else:
        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
        return iou

你可能感兴趣的:(深度学习)