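In the post-processing stage of a YOLOv3 implementation (PyTorch version) from GitHub, I replaced standard NMS with DIoU NMS, i.e. the IoU computation inside NMS was replaced by DIoU. The results are compared below (the dataset has only one class: ships in inland waterways):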
NMS:
Average Precisions:
+ Class '0' (boat) - AP: 0.9324305962764635
mAP: 0.9324305962764635
DIOU NMS:
Average Precisions:
+ Class '0' (boat) - AP: 0.9284645071187682
mAP: 0.9284645071187682
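In this experiment, DIoU NMS performed slightly worse than standard NMS (AP 0.9285 vs. 0.9324). So whether it helps in practice has to be evaluated case by case.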
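P.S. My understanding is that DIoU NMS is implemented simply by replacing IoU with DIoU in the suppression test, but I'm not entirely sure; the results above were obtained that way.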
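For reference, DIoU subtracts a normalized center-distance penalty from IoU. Here ρ is the Euclidean distance between the two box centers b_1 and b_2, and c is the diagonal length of the smallest box enclosing both. The diou=True branch of bbox_iou computes this quantity (center_diag is ρ², enclose_diag is c²). Under the assumption stated above, DIoU NMS then keeps or zeroes the score s_i of each remaining box B_i against the current top-scoring box M, using threshold ε:

```latex
\mathrm{DIoU}(B_1, B_2) = \mathrm{IoU}(B_1, B_2) - \frac{\rho^2(\mathbf{b}_1, \mathbf{b}_2)}{c^2}

s_i =
\begin{cases}
s_i, & \mathrm{DIoU}(M, B_i) < \varepsilon \\
0,   & \mathrm{DIoU}(M, B_i) \ge \varepsilon
\end{cases}
```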
IoU and DIoU code:
```python
import torch


def bbox_iou(box1, box2, x1y1x2y2=True, diou=False):
    """
    Returns the IoU (or DIoU if diou=True) of two sets of bounding boxes.
    box1, box2: tensors of shape (N, 4), as (x1, y1, x2, y2) or (cx, cy, w, h).
    """
    eps = 1e-16  # guards against division by zero
    if not x1y1x2y2:
        # Transform from center and width to exact coordinates
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    # Get the coordinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)
    # Intersection area (+1 follows the inclusive pixel-coordinate convention)
    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(
        inter_rect_y2 - inter_rect_y1 + 1, min=0
    )
    # Union area
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
    union = b1_area + b2_area - inter_area + eps
    iou = inter_area / union
    if diou:
        # Smallest box enclosing both boxes; c^2 is its squared diagonal
        enclose_l = torch.min(b1_x1, b2_x1)
        enclose_t = torch.min(b1_y1, b2_y1)
        enclose_r = torch.max(b1_x2, b2_x2)
        enclose_b = torch.max(b1_y2, b2_y2)
        enclose_w = enclose_r - enclose_l
        enclose_h = enclose_b - enclose_t
        enclose_diag = torch.pow(enclose_w, 2) + torch.pow(enclose_h, 2) + eps
        # Squared distance between the two box centers (rho^2)
        center_b1_cx = (b1_x1 + b1_x2) / 2
        center_b1_cy = (b1_y1 + b1_y2) / 2
        center_b2_cx = (b2_x1 + b2_x2) / 2
        center_b2_cy = (b2_y1 + b2_y2) / 2
        center_diag = torch.pow(center_b1_cx - center_b2_cx, 2) + torch.pow(center_b1_cy - center_b2_cy, 2)
        # DIoU = IoU - rho^2 / c^2, clamped to [-1, 1]
        return torch.clamp(iou - center_diag / enclose_diag, min=-1.0, max=1.0)
    return iou
```
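The sketch below is a minimal pure-Python illustration of greedy NMS with a pluggable overlap function. It assumes, as the post does, that DIoU NMS just swaps IoU for DIoU in the suppression test, which matches the DIoU-NMS rule described in the DIoU paper by Zheng et al. The helpers iou, diou and nms are hypothetical names for illustration, not code from the YOLOv3 repository. Boxes use continuous (x1, y1, x2, y2) coordinates without the +1 pixel convention used in bbox_iou, so the values differ slightly from that function.

```python
from typing import Callable, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), continuous coordinates
EPS = 1e-16


def iou(a: Box, b: Box) -> float:
    """Plain IoU of two boxes (no +1 pixel convention)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / (union + EPS)


def diou(a: Box, b: Box) -> float:
    """DIoU = IoU - (squared center distance) / (squared enclosing-box diagonal)."""
    cx_a, cy_a = (a[0] + a[2]) / 2, (a[1] + a[3]) / 2
    cx_b, cy_b = (b[0] + b[2]) / 2, (b[1] + b[3]) / 2
    rho2 = (cx_a - cx_b) ** 2 + (cy_a - cy_b) ** 2
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    return iou(a, b) - rho2 / (cw ** 2 + ch ** 2 + EPS)


def nms(boxes: Sequence[Box], scores: Sequence[float], thresh: float,
        overlap: Callable[[Box, Box], float] = iou) -> List[int]:
    """Greedy NMS: returns indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # Drop every remaining box whose overlap with the kept box reaches thresh
        order = [i for i in order if overlap(boxes[best], boxes[i]) < thresh]
    return keep


# Two horizontally shifted boxes plus one far-away box
boxes = [(0, 0, 10, 10), (3, 0, 13, 10), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, 0.52))                # [0, 2]
print(nms(boxes, scores, 0.52, overlap=diou))  # [0, 1, 2]
```

With these numbers, box 1 overlaps box 0 with IoU ≈ 0.538 but DIoU ≈ 0.505, so at a threshold of 0.52 standard NMS suppresses it while DIoU NMS keeps it. DIoU NMS is more reluctant to suppress boxes whose centers are far apart, which can help with adjacent objects but can also keep more duplicate detections.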