YOLO的误差主要分为三大部分:IOU损失、分类损失、坐标损失,IOU损失分为了no_objects_loss和objects_loss。
objects:anchor_boxes与ground truth的IOU最大的框
no_objects:除去IOU最大的框都是
confidence:判断锚框内是否存在检测物体
YOLO_v1:只有一个anchor_box(缺点:只能检测一个单元格内包含一个目标的情况),将目标检测看成回归问题,采用均方差损失函数。
修正系数:no_objects_loss和objects_loss分别是0.5和1
YOLO_v2:总共有845个anchor_boxes,与true_boxes匹配的用于预测pred_boxes,未与true_boxes匹配的anchor_boxes用于预测background。
YOLO_v2和YOLO_v1基本一致,就是经过softmax()后,20维向量(数据集中分类种类为20种)的均方误差。
YOLO_v1:
def yolo_loss(args,
anchors,
num_classes,
rescore_confidence=False,
print_loss=False):
"""
参数
----------
yolo_output : 神经网络最后一层的输出,shape:[batch_size,13,13,125]
true_boxes : 实际框的位置和类别,我们的输入。三个维度:
第一个维度:图片张数
第二个维度:一张图片中有几个实际框
第三个维度: [x, y, w, h, class],x,y 是实际框的中心点坐标,w,h 是框的宽度和高度。x,y,w,h 均是除以图片分辨率得到的[0,1]范围的值。
detectors_mask : 取值是0或者1,这里的shape是[ batch_size,13,13,5,1],其值可参考函数preprocess_true_boxes()的输出,五个维度:
第一个维度:图片张数
第二个维度:true_boxes的中心位于第几行(y方向上属于第几个gird cell)
第三个维度:true_boxes的中心位于第几列(x方向上属于第几个gird cell)
第四个维度:哪个anchor box
第五个维度:0/1。1的就是用于预测改true boxes 的 anchor boxes
matching_true_boxes :这里的shape是[-1,13,13,5,5],其值可参考函数preprocess_true_boxes()的输出,五个维度:
第一个维度:图片张数
第二个维度:true_boxes的中心位于第几行(y方向上属于第几个gird cel)
第三个维度:true_boxes的中心位于第几列(x方向上属于第几个gird cel)
第四个维度:第几个anchor box
第五个维度:[x,y,w,h,class]。这里的x,y表示offset,是相当于gird cell的坐标,w,h是取了log函数的,class是属于第几类。
anchors : 实际anchor boxes 的值,论文中使用了五个。[w,h],都是相对于gird cell 长宽的比值。二个维度:
第一个维度:anchor boxes的数量,这里是5
第二个维度:[w,h],w,h,都是相对于gird cell 长宽的比值。
[1.08, 1.19], [3.42, 4.41], [6.63, 11.38], [9.42, 5.11], [16.62, 10.52]
num_classes :类别个数(有多少类)
rescore_confidence : bool值,False和True计算confidence_loss的objects_loss不同,后面代码可以看到。
print_loss : bool值,是否打印损失,包括总损失,IOU损失,分类损失,坐标损失
返回值
-------
total_loss : float,总损失
"""
(yolo_output, true_boxes, detectors_mask, matching_true_boxes) = args
num_anchors = len(anchors)
object_scale = 5 '物体位于gird cell时计算置信度的修正系数'
no_object_scale = 1 '物体位于gird cell时计算置信度的修正系数'
class_scale = 1 '计算分类损失的修正系数'
coordinates_scale = 1 '计算坐标损失的修正系数'
pred_xy, pred_wh, pred_confidence, pred_class_prob = yolo_head(
yolo_output, anchors, num_classes)
yolo_output_shape = K.shape(yolo_output)
feats = K.reshape(yolo_output, [
-1, yolo_output_shape[1], yolo_output_shape[2], num_anchors,
num_classes + 5]) 'shape:[-1,13,13,5,25]'
pred_boxes = K.concatenate(
(K.sigmoid(feats[..., 0:2]), feats[..., 2:4]), axis=-1)
'合并得到pred_boxes的x,y,w,h,用于和matching_true_boxes计算坐标损失,shape:[-1,13,13,5,4]'
# Expand pred x,y,w,h to allow comparison with ground truth.
# batch, conv_height, conv_width, num_anchors, num_true_boxes, box_params
pred_xy = K.expand_dims(pred_xy, 4) '增加一个维度由[-1,13,13,5,2]变成[-1,13,13,5,1,2]'
pred_wh = K.expand_dims(pred_wh, 4) '增加一个维度由[-1,13,13,5,2]变成[-1,13,13,5,1,2]'
pred_wh_half = pred_wh / 2.
pred_mins = pred_xy - pred_wh_half
pred_maxes = pred_xy + pred_wh_half
'计算pred_boxes左上顶点和右下顶点的坐标'
true_boxes_shape = K.shape(true_boxes)
true_boxes = K.reshape(true_boxes, [true_boxes_shape[0], 1, 1, 1, true_boxes_shape[1], true_boxes_shape[2]])
'shape:[-1,1,1,1,-1,5],batch, conv_height, conv_width, num_anchors, num_true_boxes, box_params'
true_xy = true_boxes[..., 0:2]
true_wh = true_boxes[..., 2:4]
true_wh_half = true_wh / 2.
true_mins = true_xy - true_wh_half
true_maxes = true_xy + true_wh_half
'计算true_boxes左上顶点和右下顶点的坐标'
intersect_mins = K.maximum(pred_mins, true_mins)
intersect_maxes = K.minimum(pred_maxes, true_maxes)
intersect_wh = K.maximum(intersect_maxes - intersect_mins, 0.)
intersect_areas = intersect_wh[..., 0] * intersect_wh[..., 1]
pred_areas = pred_wh[..., 0] * pred_wh[..., 1]
true_areas = true_wh[..., 0] * true_wh[..., 1]
union_areas = pred_areas + true_areas - intersect_areas
iou_scores = intersect_areas / union_areas
'计算出所有anchor boxes(这里是一张图片845个)和true_boxes的IOU,shape:[-1,13,13,5,2,1]'
best_ious = K.max(iou_scores, axis=4) '这里很有意思,若两个true_boxes落在同一个gird cell里,我只取iou最大的那一个, 因为best_iou这个值只关心在这个gird cell中最大的那个iou,不关心来自于哪个true_boxes。'
best_ious = K.expand_dims(best_ious) 'shape:[1,-1,13,13,5,1]'
object_detections = K.cast(best_ious > 0.6, K.dtype(best_ious))
'选出IOU大于0.6的,不关注其损失。cast()函数,第一个参数是bool值,dtype是int,就会转换成0,1'
no_object_weights = (no_object_scale * (1 - object_detections) *
(1 - detectors_mask))
no_objects_loss = no_object_weights * K.square(-pred_confidence)
if rescore_confidence:
objects_loss = (object_scale * detectors_mask *
K.square(best_ious - pred_confidence))
else:
objects_loss = (object_scale * detectors_mask *
K.square(1 - pred_confidence))
confidence_loss = objects_loss + no_objects_loss
'计算confidence_loss,no_objects_loss是计算background的误差, objects_loss是计算与true_box匹配的anchor_boxes的误差,相比较no_objects_loss更关注这部分误差,其修正系数为5'
matching_classes = K.cast(matching_true_boxes[..., 4], 'int32')
matching_classes = K.one_hot(matching_classes, num_classes)
classification_loss = (class_scale * detectors_mask *
K.square(matching_classes - pred_class_prob))
'计算classification_loss,20维向量的差'
matching_boxes = matching_true_boxes[..., 0:4]
coordinates_loss = (coordinates_scale * detectors_mask *
K.square(matching_boxes - pred_boxes))
'计算coordinates_loss, x,y都是offset的均方损失,w,h是取了对数的均方损失,与YOLOv1中的平方根的差的均方类似,效果比其略好一点'
confidence_loss_sum = K.sum(confidence_loss)
classification_loss_sum = K.sum(classification_loss)
coordinates_loss_sum = K.sum(coordinates_loss)
total_loss = 0.5 * (
confidence_loss_sum + classification_loss_sum + coordinates_loss_sum)
if print_loss:
total_loss = tf.Print(
total_loss, [
total_loss, confidence_loss_sum, classification_loss_sum,
coordinates_loss_sum
],
message='yolo_loss, conf_loss, class_loss, box_coord_loss:')
return total_loss