yolov5代码详解-compute_loss(p, targets, model)

def compute_loss(p, targets, model):  # predictions, targets, model
    """Compute the training loss summed over every output feature map.

    Args:
        p: sequence of per-layer prediction tensors. Each ``pi`` is indexed
            below as ``pi[image, anchor, gridy, gridx, channel]``. Channels
            0-1 are xy offsets, 2-3 are wh, 4 is the objectness logit and
            5: are class logits.
        targets: ground-truth tensor. Only its ``device`` is read directly;
            it is otherwise handed to ``build_targets``.
        model: object providing ``hyp`` (dict read for the keys 'cls_pw',
            'obj_pw', 'fl_gamma', 'box', 'obj', 'cls'), ``gr`` (blend weight
            between 1.0 and the IoU for objectness targets) and ``nc``
            (number of classes; the class loss is skipped when ``nc <= 1``).

    Returns:
        Tuple ``(loss * bs, items)``:
            - ``bs`` is the first dimension of the prediction tensors (the
              image index).
            - ``items`` is the detached 4-element tensor
              ``[lbox, lobj, lcls, loss]``.
    """
    device = targets.device
    # One-element accumulators for the box, objectness and class losses,
    # summed over all output layers.
    lcls = torch.zeros(1, device=device)
    lbox = torch.zeros(1, device=device)
    lobj = torch.zeros(1, device=device)

    # Match ground truths to anchors / grid cells for every output layer.
    tcls, tbox, indices, anchors = build_targets(p, targets, model)
    h = model.hyp  # hyperparameters

    # Classification and objectness criteria: BCE on logits with an optional
    # positive-class weight taken from the hyperparameters.
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)

    # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    # cp / cn are the target values for the true class / all other classes.
    # eps=0.0 presumably disables smoothing (smooth_BCE is defined elsewhere).
    cp, cn = smooth_BCE(eps=0.0)

    # Optionally wrap both criteria in focal loss.
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    no = len(p)  # number of output layers
    # Per-layer weights applied to the objectness loss (P3-5 or P3-6).
    balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]

    for i, pi in enumerate(p):  # layer index, layer predictions
        # Matched ground truths on this layer: image, anchor, grid-y, grid-x.
        b, a, gj, gi = indices[i]
        # Objectness targets: zero everywhere except at matched cells.
        tobj = torch.zeros_like(pi[..., 0], device=device)

        n = b.shape[0]  # number of targets matched on this layer
        if n:  # box and class losses only exist where targets were matched
            # Predictions at the matched (image, anchor, cell) positions.
            ps = pi[b, a, gj, gi]

            # Regression: decode the raw outputs.
            # xy offset lies in (-0.5, 1.5) cells.
            # wh lies in (0, 4) x the anchor size.
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box as (xy, wh)
            # pbox is (xy, wh) rather than corner form, hence x1y1x2y2=False.
            iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
            lbox += (1.0 - iou).mean()  # iou loss

            # Objectness: the target at matched cells is blended between 1.0
            # and the (non-negative, gradient-free) IoU by model.gr.
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

            # Classification loss (only meaningful with multiple classes).
            if model.nc > 1:
                # Target matrix: cn everywhere, cp at each target's class.
                t = torch.full_like(ps[:, 5:], cn, device=device)
                t[range(n), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

        # Objectness loss covers every cell, matched or not, weighted per layer.
        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss

    # Scale so the loss magnitude is relative to a 3-output model.
    s = 3 / no  # output count scaling
    lbox *= h['box'] * s
    lobj *= h['obj'] * s * (1.4 if no == 4 else 1.)
    lcls *= h['cls'] * s
    # Batch size: first dimension of the (last layer's) predictions.
    bs = tobj.shape[0]

    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

 

你可能感兴趣的:(yolov5代码详解-compute_loss(p, targets, model))