Keras 版 YOLO v3 中的loss function 代码解读

参考:https://blog.csdn.net/summer_xialuo/article/details/82794076?utm_source=blogxgwz9

个人觉得,整个检测最重要的就是darknet网络与损失函数的设计。

def yolo_loss(args, anchors, num_classes, ignore_thresh=.5, print_loss=False):
    #输入:
    #args:是一个list 组合,包括了预测值和真实值,具体如下:
    #     args[:num_layers]--预测值yolo_outputs,
    #     args[num_layers:]--真实值y_true,
    #     yolo_outputs,y_true同样是list,分别是[y1,y2,y3]三个feature map 上的的预测结果,
    #     每个y都是m*grid*grid*num_anchors*(num_classes+5),作者原文是三层,分别是(13,13,3,25)\
    #     (26,26,3,25),(52,52,3,25)
    #anchors:输入预先选择好的anchors box,原文是9个box,三层feature map 各三个。
    #num_classes:原文分了20类  
    #ignore_thresh=.5:如果一个default box与true box的IOU 小于ignore_thresh, 
    #                 则作为负样本confidence 损失。
    #print_loss:loss的打印开关。
    #输出:一维张量。
    #    
    '''Return yolo_loss tensor
    Parameters
    ----------
    yolo_outputs: list of tensor, the output of yolo_body or tiny_yolo_body
    y_true: list of array, the output of preprocess_true_boxes
    anchors: array, shape=(N, 2), wh
    num_classes: integer
    ignore_thresh: float, the iou threshold whether to ignore object confidence loss
    Returns
    -------
    loss: tensor, shape=(1,)
    '''
    num_layers = len(anchors)//3 # default setting
    #args即[*model_body.output, *y_true]
    #model_body.output = [y1,y2,y3]即三个尺度的预测结果,每个y都是m*grid*grid*num_anchors*(num_classes+5)
    #m = batch_size
    yolo_outputs = args[:num_layers]
    y_true = args[num_layers:]
    anchor_mask = [[6,7,8], [3,4,5], [0,1,2]] if num_layers==3 else [[3,4,5], [1,2,3]]
    input_shape = K.cast(K.shape(yolo_outputs[0])[1:3] * 32, K.dtype(y_true[0]))#得到(416*416)
    grid_shapes = [K.cast(K.shape(yolo_outputs[l])[1:3], K.dtype(y_true[0])) for l in range(num_layers)]
    #得到三个grid的大小
    loss = 0
    m = K.shape(yolo_outputs[0])[0] # batch size, tensor
    mf = K.cast(m, K.dtype(yolo_outputs[0]))
 
    for l in range(num_layers):
        object_mask = y_true[l][..., 4:5]#置信率
        true_class_probs = y_true[l][..., 5:]#分类
 
        #将网络最后一层输出转化为BBOX的参数
        #anchors[anchor_mask[l]]:anchors对应的某一个尺度的anchor
        #例:最小尺度预测大物体:
        '''
        anchors[anchor_mask[0]]
        [[116  90]
        [156 198]
        [373 326]]
        '''
        grid, raw_pred, pred_xy, pred_wh = yolo_head(yolo_outputs[l],
             anchors[anchor_mask[l]], num_classes, input_shape, calc_loss=True)
        pred_box = K.concatenate([pred_xy, pred_wh])#相对于gird的box参数(x,y,w,h)
 
        # Darknet raw box to calculate loss.
        #这是对x,y,w,b转换公式的反变换,转换的true_box
        raw_true_xy = y_true[l][..., :2]*grid_shapes[l][::-1] - grid#保存时其实保存的是5个数(:2)就是x,y
        raw_true_wh = K.log(y_true[l][..., 2:4] / anchors[anchor_mask[l]] * input_shape[::-1])
        
        #这部操作是避免出现log(0) = 负无穷,故当object_mask置信率接近0是返回全0结果
        #K.switch(条件函数,返回值1,返回值2)其中1,2要等shape
        raw_true_wh = K.switch(object_mask, raw_true_wh, K.zeros_like(raw_true_wh)) # avoid log(0)=-inf
        box_loss_scale = 2 - y_true[l][...,2:3]*y_true[l][...,3:4]#这应该是个什么面积
 
        # Find ignore mask, iterate over each of batch.
        #tf.TensorArray--相当于建立一个动态数组,size=1,是二维动态数组。
        ignore_mask = tf.TensorArray(K.dtype(y_true[0]), size=1, dynamic_size=True)
        object_mask_bool = K.cast(object_mask, 'bool')#将真实标定的数据置信率转换为T or F的掩膜
        def loop_body(b, ignore_mask):
            #object_mask_bool(b,13,13,3,4)--五维数组,第b张图的第l层feature map.
            #true_box将第b图第i层feature map,有目标窗口的坐标位置取出来。true_box[x,y,w,h],
            true_box = tf.boolean_mask(y_true[l][b,...,0:4], object_mask_bool[b,...,0])#b是第几张图,将置信率为0的其他参数清
            iou = box_iou(pred_box[b], true_box)#单张图片单个尺度算iou,即该层所有预测窗口
            #pred_box(13,13,3,4)与真实窗口true_box(设有j个)之间的IOU,输出为iou(13,13,3,j)
            best_iou = K.max(iou, axis=-1)#先取每个grid上多个anchor box上的最大的iou
            #best_iou(13,13,3)值是最大的iou
            ignore_mask = ignore_mask.write(b, K.cast(best_iou

注意,这里的负样本选择,

1、不同Faster RCNN,设计了threashold 为0.3,以控制正负样本的比例。

2、也不同于SSD中用top k来控制负样本数量。

原文中,作者也尝试了像Faster RCNN一样,设计了0~0.3作为负样本,0.7~1.0作为正样本,但效果不好。

对此,不是很理解。

你可能感兴趣的:(Keras 版 YOLO v3 中的loss function 代码解读)