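The CIoU formula is:
CIoU = IoU - ρ²(b, b_gt) / c² - αv
where v = (4 / π²) · (arctan(w_gt / h_gt) - arctan(w / h))² measures how similar the aspect ratios are, and α = v / ((1 - IoU) + v) is its weighting coefficient.
The terms are defined as follows:
ρ²(b, b_gt): the squared Euclidean distance between b and b_gt, the center points of the predicted box and the ground-truth box
c: the diagonal length of the smallest rectangle that encloses both the predicted box and the ground-truth box
w_gt, h_gt: the width and height of the ground-truth box; w, h: the width and height of the predicted box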
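The following implements it in Python with PyTorch.
The function takes the predicted boxes and the ground-truth boxes as input:
import math
import torch

def box_ciou(b1, b2):
    # b1 (predicted boxes) / b2 (ground-truth boxes): shape (batch, 4)
    # the 4 values are in center/size form: (cx, cy, w, h)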
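Compute the top-left and bottom-right corners of the predicted and ground-truth boxes:
    # top-left and bottom-right corners of the predicted boxes
    b1_xy = b1[..., :2]
    b1_wh = b1[..., 2:4]
    b1_wh_half = b1_wh / 2.0
    b1_mins = b1_xy - b1_wh_half
    b1_maxes = b1_xy + b1_wh_half
    # top-left and bottom-right corners of the ground-truth boxes
    b2_xy = b2[..., :2]
    b2_wh = b2[..., 2:4]
    b2_wh_half = b2_wh / 2.0
    b2_mins = b2_xy - b2_wh_half
    b2_maxes = b2_xy + b2_wh_half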
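To make the center-to-corner conversion concrete, here is a minimal scalar sketch in plain Python. The helper `center_to_corners` and the two example boxes are hypothetical and are used only for illustration; they are not part of the original function.

```python
def center_to_corners(box):
    """Convert a (cx, cy, w, h) box to ((x_min, y_min), (x_max, y_max))."""
    cx, cy, w, h = box
    return (cx - w / 2.0, cy - h / 2.0), (cx + w / 2.0, cy + h / 2.0)

# Hypothetical example boxes in (cx, cy, w, h) form
pred_box = (2.0, 2.0, 2.0, 2.0)
gt_box = (3.0, 3.0, 2.0, 4.0)

pred_mins, pred_maxes = center_to_corners(pred_box)
gt_mins, gt_maxes = center_to_corners(gt_box)
print(pred_mins, pred_maxes)  # (1.0, 1.0) (3.0, 3.0)
print(gt_mins, gt_maxes)      # (2.0, 1.0) (4.0, 5.0)
```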
Compute the IoU:
    # IoU between each predicted box and its ground-truth box
    # intersection corners, shape (batch, 2)
    intersect_mins = torch.max(b1_mins, b2_mins)
    intersect_maxes = torch.min(b1_maxes, b2_maxes)
    # compare with 0 so that non-overlapping boxes do not produce negative sizes
    # shape (batch, 2)
    intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
    # areas and IoU, shape (batch,)
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / torch.clamp(union_area, min=1e-6)
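The IoU step can be checked by hand on a single pair of boxes. The sketch below uses plain Python floats; `iou_from_corners` is a hypothetical helper and the corner coordinates are made-up example values. It mirrors the clamp-at-zero and epsilon guard used in the tensor code.

```python
def iou_from_corners(mins1, maxes1, mins2, maxes2, eps=1e-6):
    """IoU of two axis-aligned boxes given as (x_min, y_min), (x_max, y_max)."""
    # Intersection width/height, clamped at 0 for non-overlapping boxes
    inter_w = max(min(maxes1[0], maxes2[0]) - max(mins1[0], mins2[0]), 0.0)
    inter_h = max(min(maxes1[1], maxes2[1]) - max(mins1[1], mins2[1]), 0.0)
    inter_area = inter_w * inter_h
    area1 = (maxes1[0] - mins1[0]) * (maxes1[1] - mins1[1])
    area2 = (maxes2[0] - mins2[0]) * (maxes2[1] - mins2[1])
    union_area = area1 + area2 - inter_area
    return inter_area / max(union_area, eps)

# Overlapping pair: intersection area 2, union area 4 + 8 - 2 = 10
iou_overlap = iou_from_corners((1.0, 1.0), (3.0, 3.0), (2.0, 1.0), (4.0, 5.0))
# Disjoint pair: the clamp makes the intersection area 0
iou_disjoint = iou_from_corners((0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
print(iou_overlap, iou_disjoint)  # 0.2 0.0
```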
Compute ρ²(b, b_gt), the squared Euclidean distance between the centers of the predicted and ground-truth boxes:
    # shape (batch,)
    center_distance = torch.sum(torch.pow(b1_xy - b2_xy, 2), dim=-1)
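Compute the diagonal of the smallest rectangle enclosing the predicted and ground-truth boxes. Note that the code keeps the squared value c², which is what the formula needs:
    enclose_mins = torch.min(b1_mins, b2_mins)
    enclose_maxes = torch.max(b1_maxes, b2_maxes)
    enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(enclose_maxes))
    # squared diagonal length c², shape (batch,)
    enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), dim=-1)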
DIoU then follows as:
    diou = iou - center_distance / torch.clamp(enclose_diagonal, min=1e-6)
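As a worked example of the distance penalty, the sketch below computes the squared center distance divided by the squared enclosing diagonal for one pair of boxes in plain Python. The function name `diou_penalty` and the example boxes are hypothetical; the IoU value 0.2 is the hand-computed IoU of this particular pair.

```python
def diou_penalty(box1, box2, eps=1e-6):
    """Squared center distance over squared enclosing diagonal, boxes as (cx, cy, w, h)."""
    cx1, cy1, w1, h1 = box1
    cx2, cy2, w2, h2 = box2
    center_dist_sq = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2
    # Width and height of the smallest rectangle enclosing both boxes
    enclose_w = max(cx1 + w1 / 2.0, cx2 + w2 / 2.0) - min(cx1 - w1 / 2.0, cx2 - w2 / 2.0)
    enclose_h = max(cy1 + h1 / 2.0, cy2 + h2 / 2.0) - min(cy1 - h1 / 2.0, cy2 - h2 / 2.0)
    diag_sq = enclose_w ** 2 + enclose_h ** 2
    return center_dist_sq / max(diag_sq, eps)

# Centers differ by (1, 1), so the squared distance is 2;
# the enclosing rectangle is 3 x 4, so the squared diagonal is 25.
penalty = diou_penalty((2.0, 2.0, 2.0, 2.0), (3.0, 3.0, 2.0, 4.0))
iou = 0.2                # IoU of the same pair, computed by hand
diou = iou - penalty     # roughly 0.12
print(penalty)  # 0.08
```

Because the penalty depends only on the centers and the enclosing rectangle, it stays informative even when the boxes do not overlap and IoU is 0.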
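Next, compute v and α, and from them CIoU:
    # aspect-ratio term v: squared difference of arctan(w / h) for the two boxes
    b1_atan = torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min=1e-6))
    b2_atan = torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min=1e-6))
    v = (4 / (math.pi ** 2)) * torch.pow(b1_atan - b2_atan, 2)
    # weighting coefficient α
    alpha = v / torch.clamp(1.0 - iou + v, min=1e-6)
    ciou = diou - alpha * v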
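The aspect-ratio term is easiest to understand on scalars. The following sketch uses only the standard `math` module; `aspect_ratio_term` and `trade_off` are hypothetical helper names, and the IoU (0.2) and DIoU (0.12) values are assumed to come from a hand-computed example pair.

```python
import math

def aspect_ratio_term(w, h, w_gt, h_gt, eps=1e-6):
    """v = 4 / pi^2 * (arctan(w_gt / h_gt) - arctan(w / h))^2"""
    diff = math.atan(w_gt / max(h_gt, eps)) - math.atan(w / max(h, eps))
    return (4.0 / math.pi ** 2) * diff ** 2

def trade_off(v, iou, eps=1e-6):
    """alpha = v / ((1 - iou) + v)"""
    return v / max(1.0 - iou + v, eps)

# Same aspect ratio (both square), different size: v is exactly 0
v_same = aspect_ratio_term(2.0, 2.0, 4.0, 4.0)
# Square prediction vs. a ground truth twice as tall as wide: v is small but positive
v_diff = aspect_ratio_term(2.0, 2.0, 2.0, 4.0)

alpha = trade_off(v_diff, iou=0.2)
ciou = 0.12 - alpha * v_diff   # DIoU of 0.12 minus the aspect-ratio penalty
print(v_same)  # 0.0
```

Since v depends only on the ratio w / h, two boxes of different size but equal proportions incur no aspect-ratio penalty.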
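Complete code:
import math
import torch

def box_ciou(b1, b2):
    """
    b1 (predicted boxes) / b2 (ground-truth boxes): shape (batch, 4), as (cx, cy, w, h)
    IoU: mainly considers the overlap area between the detected box and the target box
    GIoU: builds on IoU and handles the case where the boxes do not overlap
    DIoU: additionally takes the distance between the box centers into account
    CIoU: builds on DIoU and additionally takes the aspect ratio of the boxes into account
    ciou = iou - ρ²(b, b_gt) / c² - αv
    ρ²(b, b_gt): squared Euclidean distance between b and b_gt, the centers of the predicted and ground-truth boxes
    c: diagonal length of the smallest rectangle enclosing the predicted and ground-truth boxes
    v = 4/π² · (arctan(w_gt/h_gt) - arctan(w/h))², α = v / ((1 - iou) + v); v measures aspect-ratio similarity
    w_gt, h_gt: width and height of the ground-truth box; w, h: width and height of the predicted box
    """
    # top-left and bottom-right corners of the predicted boxes
    b1_xy = b1[..., :2]
    b1_wh = b1[..., 2:4]
    b1_wh_half = b1_wh / 2.0
    b1_mins = b1_xy - b1_wh_half
    b1_maxes = b1_xy + b1_wh_half
    # top-left and bottom-right corners of the ground-truth boxes
    b2_xy = b2[..., :2]
    b2_wh = b2[..., 2:4]
    b2_wh_half = b2_wh / 2.0
    b2_mins = b2_xy - b2_wh_half
    b2_maxes = b2_xy + b2_wh_half
    # IoU between each predicted box and its ground-truth box
    # intersection corners, shape (batch, 2)
    intersect_mins = torch.max(b1_mins, b2_mins)
    intersect_maxes = torch.min(b1_maxes, b2_maxes)
    # compare with 0 so that non-overlapping boxes do not produce negative sizes
    # shape (batch, 2)
    intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
    # areas and IoU, shape (batch,)
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / torch.clamp(union_area, min=1e-6)
    # squared distance between the centers, ρ²(b, b_gt)
    # shape (batch,)
    center_distance = torch.sum(torch.pow(b1_xy - b2_xy, 2), dim=-1)
    # top-left and bottom-right corners of the smallest box enclosing both boxes
    enclose_mins = torch.min(b1_mins, b2_mins)
    enclose_maxes = torch.max(b1_maxes, b2_maxes)
    enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(enclose_maxes))
    # squared diagonal length c², shape (batch,)
    enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), dim=-1)
    diou = iou - center_distance / torch.clamp(enclose_diagonal, min=1e-6)
    # aspect-ratio term v and its weighting coefficient α
    b1_atan = torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min=1e-6))
    b2_atan = torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min=1e-6))
    v = (4 / (math.pi ** 2)) * torch.pow(b1_atan - b2_atan, 2)
    alpha = v / torch.clamp(1.0 - iou + v, min=1e-6)
    ciou = diou - alpha * v
    # shape (batch,)
    return ciou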
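One way to sanity-check the tensor implementation is to compare it against a scalar version on a few boxes whose CIoU is easy to reason about. The sketch below is such a reference in plain Python; `ciou_scalar` is a hypothetical name, and the code is an illustrative cross-check under the same (cx, cy, w, h) convention rather than a replacement for `box_ciou`.

```python
import math

def ciou_scalar(box1, box2, eps=1e-6):
    """CIoU of two (cx, cy, w, h) boxes using plain floats."""
    cx1, cy1, w1, h1 = box1
    cx2, cy2, w2, h2 = box2
    # Corner coordinates of both boxes
    x1_min, x1_max = cx1 - w1 / 2.0, cx1 + w1 / 2.0
    y1_min, y1_max = cy1 - h1 / 2.0, cy1 + h1 / 2.0
    x2_min, x2_max = cx2 - w2 / 2.0, cx2 + w2 / 2.0
    y2_min, y2_max = cy2 - h2 / 2.0, cy2 + h2 / 2.0
    # IoU, with the intersection clamped at 0
    inter_w = max(min(x1_max, x2_max) - max(x1_min, x2_min), 0.0)
    inter_h = max(min(y1_max, y2_max) - max(y1_min, y2_min), 0.0)
    inter = inter_w * inter_h
    iou = inter / max(w1 * h1 + w2 * h2 - inter, eps)
    # Squared center distance and squared enclosing diagonal
    center_sq = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2
    diag_sq = ((max(x1_max, x2_max) - min(x1_min, x2_min)) ** 2
               + (max(y1_max, y2_max) - min(y1_min, y2_min)) ** 2)
    # Aspect-ratio term and its weight
    v = (4.0 / math.pi ** 2) * (math.atan(w1 / max(h1, eps)) - math.atan(w2 / max(h2, eps))) ** 2
    alpha = v / max(1.0 - iou + v, eps)
    return iou - center_sq / max(diag_sq, eps) - alpha * v

# Identical boxes: IoU is 1 and both penalties vanish
same = ciou_scalar((2.0, 2.0, 2.0, 2.0), (2.0, 2.0, 2.0, 2.0))
# Disjoint boxes: IoU is 0, so only the (negative) distance penalty remains
apart = ciou_scalar((0.5, 0.5, 1.0, 1.0), (2.5, 2.5, 1.0, 1.0))
print(same)       # 1.0
print(apart < 0)  # True
```

Feeding the same boxes to `box_ciou` as tensors of shape (1, 4) should give matching values up to floating-point precision.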