class IOUloss(nn.Module):
    """Box-regression loss built on IoU and its GIoU / DIoU / CIoU variants.

    Each box is read as four values ``(cx, cy, w, h)``: corners are derived
    as ``center -/+ size / 2`` and areas as ``w * h``.

    Args:
        reduction: ``"mean"`` or ``"sum"`` collapse the per-box losses to a
            scalar; any other value (default ``"none"``) returns one loss
            per box.
        loss_type: ``"iou"`` (``1 - iou**2``), ``"giou"``, ``"diou"`` or
            ``"ciou"``.
    """

    # Guards every division against a zero denominator (degenerate boxes).
    _EPS = 1e-16
    _LOSS_TYPES = ("iou", "giou", "diou", "ciou")

    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    @staticmethod
    def _center_distance_penalty(pred, target, c_tl, c_br, eps):
        """Return d^2 / c^2: squared centre distance over squared enclosing diagonal."""
        enclose_wh = c_br - c_tl
        diag_sq = enclose_wh[:, 0] ** 2 + enclose_wh[:, 1] ** 2
        center_delta = pred[:, :2] - target[:, :2]
        dist_sq = center_delta[:, 0] ** 2 + center_delta[:, 1] ** 2
        # eps keeps the ratio finite when the enclosing box has zero size.
        return dist_sq / (diag_sq + eps)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the configured loss between predicted and target boxes.

        Args:
            pred: predicted boxes; reshaped to ``(-1, 4)``.
            target: target boxes; reshaped to ``(-1, 4)``. Must have the same
                first dimension as ``pred``.

        Returns:
            Per-box loss of shape ``(N,)``, or a scalar when ``reduction`` is
            ``"mean"`` or ``"sum"``.

        Raises:
            ValueError: if ``loss_type`` is not one of the supported names.
        """
        assert pred.shape[0] == target.shape[0]
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(
                "Unsupported loss_type {!r}; expected one of {}".format(
                    self.loss_type, self._LOSS_TYPES
                )
            )

        eps = self._EPS
        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Corner coordinates, computed once and shared by every variant.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # `en` is 1 only where the boxes overlap on both axes; it zeroes the
        # (otherwise positive) product of two negative extents.
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = area_i / (area_u + eps)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        else:
            # Smallest axis-aligned box enclosing both boxes.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)

            if self.loss_type == "giou":
                area_c = torch.prod(c_br - c_tl, 1)
                giou = iou - (area_c - area_u) / area_c.clamp(eps)
                loss = 1 - giou.clamp(min=-1.0, max=1.0)
            else:
                diou = iou - self._center_distance_penalty(
                    pred, target, c_tl, c_br, eps
                )
                if self.loss_type == "diou":
                    loss = 1 - diou.clamp(min=-1.0, max=1.0)
                else:  # "ciou"
                    w_gt, h_gt = target[:, 2], target[:, 3]
                    w, h = pred[:, 2], pred[:, 3]
                    # Aspect-ratio consistency term. It must stay in the
                    # autograd graph, otherwise CIoU trains exactly like DIoU.
                    arctan = torch.atan(w_gt / (h_gt + eps)) - torch.atan(
                        w / (h + eps)
                    )
                    v = (4 / (math.pi ** 2)) * torch.pow(arctan, 2)
                    # Only the trade-off weight is treated as a constant.
                    # eps avoids 0/0 (NaN) when iou == 1 and v == 0.
                    with torch.no_grad():
                        alpha = v / (1 - iou + v + eps)
                    ciou = diou - alpha * v
                    loss = 1 - ciou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
# Original construction: loss_type is left at its default, so the plain
# "iou" branch of IOUloss is used and per-box losses are returned unreduced.
# NOTE(review): `self` refers to an owner object not visible here - presumably
# this line lives in the detection head's __init__; confirm.
self.iou_loss = IOUloss(reduction="none")
# Now replaced with one of the following:
# Alternative constructions selecting the "giou", "diou" or "ciou" branch of
# IOUloss, each returning unreduced per-box losses.
# NOTE(review): these read as mutually exclusive options - keep only one. If
# all three run in sequence, only the last assignment (ciou) takes effect.
self.iou_loss = IOUloss(reduction="none", loss_type="giou")
self.iou_loss = IOUloss(reduction="none", loss_type="diou")
self.iou_loss = IOUloss(reduction="none", loss_type="ciou")
```
Source: compiled from material found online.