【AI】Pytorch实现IoU

只要记住一点:

不论两个框长什么样,交集左上角的x和y一定取max()
                                    交集右下角的x和y一定取min()

 

【AI】Pytorch实现IoU_第1张图片

  

【AI】Pytorch实现IoU_第2张图片

代码比较简单,要注意的是,需要考虑输入的坐标格式。 

import torch

def cal_iou(box_pre, box_label, cor_format):
    """Compute the intersection-over-union (IoU) between boxes.

    Args:
        box_pre: Tensor whose last dimension holds the 4 box coordinates,
            e.g. shape [N, 4].
        box_label: Tensor whose last dimension holds the 4 box coordinates
            and which broadcasts against ``box_pre``, e.g. shape [1, 4].
        cor_format: ``'center'`` if boxes are given as (cx, cy, w, h),
            ``'corner'`` if boxes are given as (x1, y1, x2, y2).

    Returns:
        Tensor of IoU values with a trailing dimension of size 1
        (e.g. shape [N, 1] for the shapes above).

    Raises:
        AssertionError: If ``cor_format`` is neither ``'center'`` nor
            ``'corner'``.
    """
    assert cor_format == 'center' or cor_format == 'corner', 'wrong cor_format'

    # Slices such as 0:1 (instead of plain index 0) keep the last dimension,
    # so every intermediate tensor stays [..., 1] and broadcasts cleanly.
    if cor_format == 'center':
        # (cx, cy, w, h): move half the width/height away from the center
        # to obtain the top-left and bottom-right corners.
        box1_x1 = box_pre[..., 0:1] - box_pre[..., 2:3] / 2
        box1_y1 = box_pre[..., 1:2] - box_pre[..., 3:4] / 2
        box1_x2 = box_pre[..., 0:1] + box_pre[..., 2:3] / 2
        box1_y2 = box_pre[..., 1:2] + box_pre[..., 3:4] / 2

        box2_x1 = box_label[..., 0:1] - box_label[..., 2:3] / 2
        box2_y1 = box_label[..., 1:2] - box_label[..., 3:4] / 2
        box2_x2 = box_label[..., 0:1] + box_label[..., 2:3] / 2
        box2_y2 = box_label[..., 1:2] + box_label[..., 3:4] / 2
    elif cor_format == 'corner':
        # (x1, y1, x2, y2): the corners are given directly.
        box1_x1 = box_pre[..., 0:1]
        box1_y1 = box_pre[..., 1:2]
        box1_x2 = box_pre[..., 2:3]
        box1_y2 = box_pre[..., 3:4]

        box2_x1 = box_label[..., 0:1]
        box2_y1 = box_label[..., 1:2]
        box2_x2 = box_label[..., 2:3]
        box2_y2 = box_label[..., 3:4]

    # Intersection rectangle: its top-left corner is the element-wise max of
    # the two top-left corners, its bottom-right corner the element-wise min
    # of the two bottom-right corners.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) turns negative extents into 0, which covers boxes that do not
    # overlap at all.
    inter = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1)
    box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1)

    # The sum of both areas counts the overlap twice, so subtract it once to
    # get the true union area.
    union = box1_area + box2_area - inter

    # The small epsilon guards against division by zero for degenerate boxes.
    iou = inter / (union + 1e-6)
    return iou

# Demo: four predicted boxes in corner format (x1, y1, x2, y2), one per row.
# The last one does not overlap the label box at all.
box_pre = torch.tensor([
    [0, 0, 2, 2],
    [0, 1, 2, 3],
    [1, 0, 3, 2],
    [7, 8, 9, 10],
])
# A single label box, also in corner format.
box_label = torch.tensor([[1, 1, 3, 3]])

iou = cal_iou(box_pre, box_label, cor_format='corner')
print(iou)

>>>
tensor([[0.1250],
        [0.2500],
        [0.2500],
        [0.0000]])
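注意:按标准定义,IoU = 交集 / 并集 = inter / (box1_area + box2_area − inter)。对这四个框,标准 IoU 应为 0.1429、0.3333、0.3333、0.0000;上面的 0.1250、0.2500 是分母未减去交集(即 box1_area + box2_area)时得到的数值。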

 以上。

你可能感兴趣的:(机器学习/深度学习,人工智能,pytorch,深度学习,目标检测,计算机视觉)