YOLOv4loss重写()

YOLOv4_CIOU函数


该文章仅记录自己学习历程,若有问题还请指正.
yolov4中使用的是CIOU计算loss,若对于IOU的各种方法不明白的还请看下这个视频 :IOU计算

CIOU的公式表述如下:

YOLOv4loss重写()_第1张图片
各部分表述如下:
YOLOv4loss重写()_第2张图片
YOLOv4loss重写()_第3张图片
p²(b,bgt):b,bgt分别代表预测框和真实框的中心点 两点的欧式距离
c:预测框与真实框最小外接矩形的对角线距离
wgt:真实框的宽 hgt:真实框的高 w:预测框的宽 h:预测框的高

下面通过python实现:

需要传入预测框和真实框的信息:

def  box_ciou(b1,b2):
 #  b1(预测框)/b2(真实框):(batch,4)
 #  4中的信息为中心宽高的形式

求出预测框和真实框的左上角右下角:

 # 求出预测框左上角右下角
    b1_xy = b1[..., :2]
    b1_wh = b1[..., 2:4]
    b1_wh_half = b1_wh/2.
    b1_mins = b1_xy - b1_wh_half
    b1_maxes = b1_xy + b1_wh_half
    # 求出真实框左上角右下角
    b2_xy = b2[..., :2]
    b2_wh = b2[..., 2:4]
    b2_wh_half = b2_wh/2.
    b2_mins = b2_xy - b2_wh_half
    b2_maxes = b2_xy + b2_wh_half

求出iou:

# 求真实框和预测框所有的iou
    #  (batch, 2)
    intersect_mins = torch.max(b1_mins, b2_mins)
    intersect_maxes = torch.min(b1_maxes, b2_maxes)
    # 避免出现负数,与0相比
    # (batch, 2)
    intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
    #  iou面积 (batch,1)
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / torch.clamp(union_area,min = 1e-6)

计算真实框和预测框的中心位置的欧式距离p²(b,bgt):

# (batch,1)
    center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)

计算真实框和预测框最大外接矩形的对角线距离c:

    enclose_mins = torch.min(b1_mins, b2_mins)
    enclose_maxes = torch.max(b1_maxes, b2_maxes)
    enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
    # 计算对角线距离(batch,1)
    enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)

diou可以知道为:

    diou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)

接下莱求出av及ciou:

    v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0]/torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0]/torch.clamp(b2_wh[..., 1],min = 1e-6))), 2)
    alpha = v / torch.clamp((1.0 - iou + v),min=1e-6)
    ciou = diou - alpha * v

完整代码:

def box_ciou(b1, b2):
    """
    b1(预测框)/b2(真实框):(batch,4)
    iou:主要考虑检测框和目标框的重叠面积
    Giou:在IOU的基础上,解决了边界框不重合时的问题
    Diou:在以上基础上,考虑边界框中心点距离的信息
    Ciou:在Diou基础上,考虑边界框宽高比的尺度信息
    ciou = iou - p²(b,bgt)/c² -av
    p²(b,bgt):b,bgt分别代表预测框和真实框的中心点  两点的欧式距离
    c:预测框与真实框最小外接矩形的对角线距离
    v= 4/π²(arctan(wgt/hgt)-arctan(w/h))²,a = v/((1-iou)+v)  v用来度量长宽比的相似性
    wgt:真实框的宽 hgt:真实框的高 w:预测框的宽 h:预测框的高
    """
    # 求出预测框左上角右下角
    b1_xy = b1[..., :2]
    b1_wh = b1[..., 2:4]
    b1_wh_half = b1_wh/2.
    b1_mins = b1_xy - b1_wh_half
    b1_maxes = b1_xy + b1_wh_half
    # 求出真实框左上角右下角
    b2_xy = b2[..., :2]
    b2_wh = b2[..., 2:4]
    b2_wh_half = b2_wh/2.
    b2_mins = b2_xy - b2_wh_half
    b2_maxes = b2_xy + b2_wh_half

    # 求真实框和预测框所有的iou
    #  (batch, 2)
    intersect_mins = torch.max(b1_mins, b2_mins)
    intersect_maxes = torch.min(b1_maxes, b2_maxes)
    # 避免出现负数,与0相比
    # (batch, 2)
    intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
    #  iou面积 (batch,1)
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / torch.clamp(union_area,min = 1e-6)

    # 计算中心的差距  p²(b,bgt)
    # (batch,1)
    center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
    
    # 找到包裹两个框的最小框的左上角和右下角  c
    enclose_mins = torch.min(b1_mins, b2_mins)
    enclose_maxes = torch.max(b1_maxes, b2_maxes)
    enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
    # 计算对角线距离(batch,1)
    enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
    diou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
    
    v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0]/torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0]/torch.clamp(b2_wh[..., 1],min = 1e-6))), 2)
    alpha = v / torch.clamp((1.0 - iou + v),min=1e-6)
    ciou = diou - alpha * v
    # (batch,1)
    return ciou

你可能感兴趣的:(python,YOLO,python)