def compute_loss(p, targets, model):  # predictions, targets, model
    """Compute the detection loss (box + objectness + classification) over all output layers.

    Args:
        p: sequence of per-layer prediction tensors. The indexing below
            (``pi[b, a, gj, gi]``, ``pi[..., 4]``, ``ps[:, :2]``, ``ps[:, 2:4]``,
            ``ps[:, 5:]``) implies each tensor is laid out as
            (image, anchor, grid_y, grid_x, outputs) with outputs ordered
            [x, y, w, h, objectness, class logits...] -- TODO confirm against
            the model's detection head.
        targets: ground-truth tensor. Only its ``device`` is read directly here;
            the rest is consumed by ``build_targets``.
        model: object exposing ``hyp`` (hyperparameter dict with keys
            'cls_pw', 'obj_pw', 'fl_gamma', 'box', 'obj', 'cls'), ``gr``
            (IoU ratio used for objectness targets) and ``nc`` (number of classes).

    Returns:
        Tuple ``(loss * bs, loss_items)``. ``loss`` is the 1-element sum of
        the scaled box, objectness and classification losses. ``bs`` is the
        first dimension of the prediction tensors. ``loss_items`` is the
        detached 4-element tensor ``[lbox, lobj, lcls, loss]``.
    """
    device = targets.device

    # One accumulator per loss component, summed over every output layer.
    lcls = torch.zeros(1, device=device)
    lbox = torch.zeros(1, device=device)
    lobj = torch.zeros(1, device=device)

    # Match ground truths to anchors/grid cells on each output layer.
    tcls, tbox, indices, anchors = build_targets(p, targets, model)
    h = model.hyp  # hyperparameters

    # Classification and objectness criteria with configurable positive weights.
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)

    # Class label smoothing targets (positive, negative); eps=0.0 disables smoothing.
    # https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # Optionally wrap both criteria in focal loss.
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    no = len(p)  # number of output layers
    # Per-layer objectness weights: P3-P5 (3 layers) or P3-P6 (4 layers).
    balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]

    for i, pi in enumerate(p):  # layer index, layer predictions
        # Matched targets on this layer: image, anchor, grid-y and grid-x indices.
        b, a, gj, gi = indices[i]
        tobj = torch.zeros_like(pi[..., 0], device=device)  # objectness targets
        n = b.shape[0]  # number of matched targets on this layer

        # Box regression and classification only apply when targets matched;
        # the objectness loss below is computed for every layer regardless.
        if n:
            ps = pi[b, a, gj, gi]  # predictions at the matched cells

            # Decode: xy offset into (-0.5, 1.5), wh into (0, 4) * anchor.
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box (xywh)
            iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
            lbox += (1.0 - iou).mean()  # CIoU loss

            # Objectness target at matched cells: blend of 1 and the detached,
            # non-negative IoU, weighted by model.gr.
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(
                tobj.dtype
            )

            # Classification loss, only computed when there is more than one class.
            if model.nc > 1:
                t = torch.full_like(ps[:, 5:], cn, device=device)  # negative targets
                t[range(n), tcls[i]] = cp  # positive target for each matched class
                lcls += BCEcls(ps[:, 5:], t)

        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # objectness loss

    # Rescale each component by its hyperparameter gain and the layer count.
    s = 3 / no  # output count scaling
    lbox *= h['box'] * s
    lobj *= h['obj'] * s * (1.4 if no == 4 else 1.)
    lcls *= h['cls'] * s

    bs = tobj.shape[0]  # batch size (first dim of the prediction tensors)
    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()