[pytorch yolo损失函数] yolo的损失函数[alpha_iou,alpha_diou,alpha_giou,alpha_ciou,alpha_siou]

YOLO 目标检测：做个简单的学习记录。

损失函数,参考1

参考链接,alpha版的各种损失

alpha_iou,alpha_diou,alpha_giou,alpha_ciou,alpha_siou

yolov6代码链接

import math
import torch

class IOUloss(torch.nn.Module):
    """Alpha-IoU family bounding-box regression loss (IoU / GIoU / DIoU / CIoU / SIoU).

    The IoU term and the GIoU/DIoU/CIoU penalty terms are raised to the power
    ``alpha`` (Alpha-IoU formulation). ``alpha=1`` gives the standard losses.
    The SIoU penalty is not powered.

    Args:
        box_format: (string), must be one of 'xywh' (center x, center y, width,
            height) or 'xyxy' (x1, y1, x2, y2).
        iou_type: (string or None), one of 'giou', 'diou', 'ciou', 'siou'
            (case-insensitive). None, or any other value, gives the plain IoU loss.
        reduction: (string), 'none', 'mean' or 'sum'. Any other value behaves
            like 'none'.
        eps: (float), small constant used to avoid division by zero.
        alpha: (float), power applied to the IoU and penalty terms.
    """

    def __init__(self, box_format='xywh', iou_type=None, reduction='none', eps=1e-7, alpha=1):
        super().__init__()
        self.box_format = box_format
        # iou_type defaults to None; calling None.lower() would raise AttributeError.
        self.iou_type = iou_type.lower() if iou_type is not None else None
        self.reduction = reduction
        self.eps = eps
        self.alpha = alpha

    def forward(self, box1, box2):
        """Compute the loss between paired boxes.

        Args:
            box1, box2: tensors whose last dimension holds the 4 box coordinates
                (in ``self.box_format``). Their leading dimensions must be
                broadcast-compatible. Boxes are compared element-wise (paired),
                not all-vs-all.

        Returns:
            Tensor with a trailing dimension of size 1 when reduction is 'none',
            otherwise a scalar tensor.

        Raises:
            ValueError: if ``box_format`` is neither 'xyxy' nor 'xywh'.
        """
        if self.box_format == 'xyxy':
            b1_x1, b1_y1, b1_x2, b1_y2 = torch.split(box1, 1, dim=-1)
            b2_x1, b2_y1, b2_x2, b2_y2 = torch.split(box2, 1, dim=-1)
        elif self.box_format == 'xywh':
            b1_cx, b1_cy, b1_w, b1_h = torch.split(box1, 1, dim=-1)
            b2_cx, b2_cy, b2_w, b2_h = torch.split(box2, 1, dim=-1)
            b1_x1, b1_x2 = b1_cx - b1_w / 2, b1_cx + b1_w / 2
            b1_y1, b1_y2 = b1_cy - b1_h / 2, b1_cy + b1_h / 2
            b2_x1, b2_x2 = b2_cx - b2_w / 2, b2_cx + b2_w / 2
            b2_y1, b2_y2 = b2_cy - b2_h / 2, b2_cy + b2_h / 2
        else:
            raise ValueError(
                f"box_format must be 'xyxy' or 'xywh', got {self.box_format!r}")

        # Intersection area (clamped to zero when the boxes do not overlap).
        inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
                (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

        # Union area. The eps on the heights keeps the w/h ratios used by CIoU finite.
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps
        union = w1 * h1 + w2 * h2 - inter + self.eps
        raw_iou = inter / union
        iou = raw_iou ** self.alpha

        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (enclosing) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex (enclosing) height
        if self.iou_type == 'giou':
            c_area = cw * ch + self.eps  # convex area
            # For (near-)identical boxes the eps terms can make union slightly larger
            # than c_area. A negative base raised to a fractional alpha yields NaN.
            iou = iou - ((c_area - union).clamp(min=0) / c_area) ** self.alpha
        elif self.iou_type in ('diou', 'ciou'):
            c2 = cw ** 2 + ch ** 2 + self.eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            penalty = rho2 ** self.alpha / c2 ** self.alpha
            if self.iou_type == 'ciou':
                # Aspect-ratio consistency term.
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    # CIoU trade-off weight v / ((1 - IoU) + v) uses the plain IoU,
                    # not the alpha-powered one (they coincide only when alpha == 1).
                    ciou_weight = v / (v - raw_iou + (1 + self.eps))
                penalty = penalty + (v * ciou_weight) ** self.alpha
            iou = iou - penalty
        elif self.iou_type == 'siou':
            # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
            s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + self.eps
            s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + self.eps
            sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
            sin_alpha_1 = torch.abs(s_cw) / sigma
            sin_alpha_2 = torch.abs(s_ch) / sigma
            threshold = math.sqrt(2) / 2
            # Use the smaller of the two angles (w.r.t. the x or the y axis).
            sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
            angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
            rho_x = (s_cw / cw) ** 2
            rho_y = (s_ch / ch) ** 2
            gamma = angle_cost - 2
            distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
            omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
            omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
            shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
            iou = iou - 0.5 * (distance_cost + shape_cost)
        loss = 1.0 - iou

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss

你可能感兴趣的：(python, pytorch, YOLO, 深度学习)