Trying out CenterNet (Objects as Points) [with tf.slim]

I recently read "Objects as Points" (https://arxiv.org/pdf/1904.07850.pdf) and found it very inspiring. There are plenty of blog posts dissecting this paper, and many of them summarize it well.

The following figure, I think, best captures the core idea:

 

[Figure 1: overview of the CenterNet approach]

The network runs a backbone to produce a feature map, then a multi-branch detection head predicts the per-class center-point heatmap, the object width/height, and the sub-pixel center offset. Below is my hand-written detection head; feat is the feature map produced by some backbone:

def center_branch(feat, n_classes):
    # Heatmap head: per-class center-point heatmap, sigmoid-activated
    hm = slim.conv2d(feat, 256, [3, 3])
    hm = slim.conv2d(hm, n_classes, [1, 1], activation_fn=tf.sigmoid)
    # Size head: object width and height (2 channels), no activation
    wh = slim.conv2d(feat, 256, [3, 3])
    wh = slim.conv2d(wh, 2, [1, 1], activation_fn=None)
    # Offset head: sub-pixel center offset (2 channels), no activation
    reg = slim.conv2d(feat, 256, [3, 3])
    reg = slim.conv2d(reg, 2, [1, 1], activation_fn=None)
    return {'hm': hm, 'wh': wh, 'reg': reg}

For the losses, wh and reg use an L1 loss, while hm uses the penalty-reduced focal loss from the paper:

[Figure 2: the penalty-reduced focal loss formula]
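This loss can be sketched in plain NumPy (a minimal illustration, not the training code; the function name penalty_reduced_focal_loss is mine, and alpha=2, beta=4 are the values used in the paper):

```python
import numpy as np

def penalty_reduced_focal_loss(y_true, y_pred, alpha=2.0, beta=4.0, eps=1e-12):
    """Penalty-reduced pixel-wise focal loss (illustrative sketch).

    y_true: Gaussian-splatted heatmap in [0, 1]; exactly 1 at object centers.
    y_pred: predicted heatmap after sigmoid, in (0, 1).
    """
    pos_mask = (y_true == 1.0).astype(np.float32)
    neg_mask = 1.0 - pos_mask
    # Positive centers: standard focal term.
    pos_loss = -((1.0 - y_pred) ** alpha) * np.log(y_pred + eps) * pos_mask
    # Negatives: penalty reduced near centers by the (1 - y_true)^beta factor.
    neg_loss = -((1.0 - y_true) ** beta) * (y_pred ** alpha) \
               * np.log(1.0 - y_pred + eps) * neg_mask
    n_pos = max(pos_mask.sum(), 1.0)  # avoid division by zero when no objects
    return (pos_loss.sum() + neg_loss.sum()) / n_pos
```

A near-perfect prediction (0.99 at the center, 0.01 elsewhere) yields a loss close to zero, while a flat 0.5 prediction is heavily penalized.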

My hand-written loss function:

def loss_total(y_true, output):
    '''
    y_true: [N, 128, 128, 24] (VOC example: the first 20 channels are per-class
            heatmaps, followed by 2 wh channels and 2 reg channels)
    output: dict returned by center_branch, with keys 'hm', 'wh', 'reg'
    returns: total loss plus the per-batch mean of each loss term
    '''
    # Split the label tensor into a dict for convenience
    hm = y_true[..., :20]
    wh = y_true[..., 20:22]
    reg = y_true[..., 22:24]
    y_true = {'hm': hm, 'wh': wh, 'reg': reg}

    mask = tf.equal(y_true['hm'], 1)
    mask = tf.cast(mask, tf.float32)                 # [batch, 128, 128, 20]
    mask_reg = tf.reduce_max(mask, axis=-1)
    mask_reg = tf.expand_dims(mask_reg, -1)
    mask_reg = tf.concat([mask_reg, mask_reg], -1)   # [batch, 128, 128, 2]
    N = tf.reduce_sum(mask, 1)
    N = tf.reduce_sum(N, 1)
    N = tf.reduce_sum(N, 1)                          # [batch] objects per image
    N = tf.maximum(N, 1.0)                           # guard against images with no objects
    loss_hm_pos = -1.0 * tf.pow(1. - output['hm'], 2.) * tf.log(output['hm'] + 1e-12) * mask
    loss_hm_neg = -1.0 * tf.pow(1. - y_true['hm'], 4) * tf.pow(output['hm'], 2) * tf.log(1. - output['hm'] + 1e-12) * (1. - mask)
    loss_hm = tf.reduce_sum(loss_hm_pos + loss_hm_neg, axis=1)
    loss_hm = tf.reduce_sum(loss_hm, axis=1)
    loss_hm = tf.reduce_sum(loss_hm, axis=1) / N
    loss_wh = tf.abs(y_true['wh'] - output['wh']) * mask_reg
    loss_wh = tf.reduce_sum(loss_wh, axis=1)
    loss_wh = tf.reduce_sum(loss_wh, axis=1)
    loss_wh = tf.reduce_sum(loss_wh, axis=1) / N
    loss_reg = tf.abs(y_true['reg'] - output['reg']) * mask_reg
    loss_reg = tf.reduce_sum(loss_reg, axis=1)
    loss_reg = tf.reduce_sum(loss_reg, axis=1)
    loss_reg = tf.reduce_sum(loss_reg, axis=1) / N
    loss_total = tf.reduce_mean(loss_hm + 0.1 * loss_wh + loss_reg)
    return loss_total, tf.reduce_mean(loss_hm), tf.reduce_mean(loss_wh), tf.reduce_mean(loss_reg)
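The masking and normalization logic above can be checked with a small NumPy example (toy 4x4 grid with 2 classes instead of [batch, 128, 128, 20]):

```python
import numpy as np

# Toy per-class heatmap label: 2 classes on a 4x4 grid, two object centers.
hm = np.zeros((4, 4, 2), dtype=np.float32)
hm[1, 1, 0] = 1.0   # object of class 0
hm[2, 3, 1] = 1.0   # object of class 1

mask = (hm == 1.0).astype(np.float32)           # positive-center mask per class
mask_reg = mask.max(axis=-1, keepdims=True)     # class-agnostic center mask
mask_reg = np.concatenate([mask_reg, mask_reg], axis=-1)  # duplicate to 2 channels

N = mask.sum()                                  # number of objects in the image
assert N == 2.0

# L1 size loss, counted only at the two center cells and averaged per object.
wh_true = np.ones((4, 4, 2), dtype=np.float32)
wh_pred = np.zeros((4, 4, 2), dtype=np.float32)
loss_wh = (np.abs(wh_true - wh_pred) * mask_reg).sum() / N
assert loss_wh == 2.0   # 2 centers x 2 channels x |1 - 0|, divided by 2 objects
```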

Label encoding for a single image:

def encode_label(label, n_classes, h, w):
    '''
    label: "image_path x1,y1,x2,y2,cls x1,y1,x2,y2,cls ..."
    e.g. "/VOCdevkit/VOC2007/JPEGImages/000006.jpg 187,135,282,242,15 154,209,369,375,10 255,207,366,375,8 138,211,249,375,8"
    n_classes: number of classes
    h, w: height and width of the original image
    '''
    label = label.split()
    # Allocate the target maps
    hm = np.zeros((128, 128, n_classes), dtype=np.float32)
    wh = np.zeros((128, 128, 2), dtype=np.float32)
    reg = np.zeros((128, 128, 2), dtype=np.float32)
    for obj in label[1:]:
        obj = np.asarray(obj.split(','), dtype=np.float32)
        # Map the box center and size onto the 128x128 output grid
        ct_item = np.asarray((obj[2:4] + obj[:2]) / 2., dtype=np.float32) / np.asarray([w, h], dtype=np.float32) * 128
        wh_item = np.asarray(obj[2:4] - obj[:2], dtype=np.float32) / np.asarray([w, h], dtype=np.float32) * 128
        # Class index
        cat_item = np.asarray(obj[4], dtype=np.int32)
        # Write the wh and reg targets at the center cell
        # (clip so a box touching the right/bottom edge cannot index out of bounds)
        x, y = min(int(ct_item[0]), 127), min(int(ct_item[1]), 127)
        reg[y, x, :] = [ct_item[1] - y, ct_item[0] - x]  # note: stored in (dy, dx) order; decoding must match
        wh[y, x, :] = [wh_item[1], wh_item[0]]           # note: stored in (h, w) order
        # Splat a Gaussian at the object center
        radius = gaussian_radius(wh_item)  # Gaussian radius from the box size
        radius = max(0, int(radius))
        hm_slice = hm[:, :, cat_item]      # heatmap for this class
        hm_slice = draw_umich_gaussian(hm_slice, (x, y), radius, k=1)  # draw the Gaussian onto it
        hm[:, :, cat_item] = hm_slice      # write it back
    # Concatenate into one [128, 128, n_classes + 4] label tensor
    y_true = np.concatenate((hm, wh, reg), axis=2)
    return y_true
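A quick sanity check of the coordinate math in encode_label (the box is the first object from the VOC example above; the 500x375 image size is my assumption for illustration):

```python
import numpy as np

# Toy box on an assumed 500x375 image (w=500, h=375).
w, h = 500.0, 375.0
x1, y1, x2, y2 = 187.0, 135.0, 282.0, 242.0

# Center and size mapped onto the 128x128 output grid.
ct = np.array([(x1 + x2) / 2.0, (y1 + y2) / 2.0]) / np.array([w, h]) * 128
wh = np.array([x2 - x1, y2 - y1]) / np.array([w, h]) * 128

x, y = int(ct[0]), int(ct[1])
offset = ct - np.array([x, y])  # sub-pixel residual recovered by the reg head

# The offset is always in [0, 1): flooring loses less than one output cell,
# which is exactly what the reg head is trained to recover.
assert 0.0 <= offset[0] < 1.0 and 0.0 <= offset[1] < 1.0
```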

The gaussian_radius and draw_umich_gaussian helpers follow the official source code:

https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py
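For reference, here is a simplified NumPy sketch of what draw_umich_gaussian does (the official version also clips the Gaussian patch at the image borders; the name draw_gaussian is mine):

```python
import numpy as np

def draw_gaussian(heatmap, center, radius, k=1.0):
    """Splat a 2D Gaussian of the given radius onto heatmap at center=(x, y).

    Simplified sketch: assumes the Gaussian patch lies fully inside the map;
    the official draw_umich_gaussian additionally clips at the borders.
    """
    diameter = 2 * radius + 1
    sigma = diameter / 6.0
    ys, xs = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gaussian = np.exp(-(xs * xs + ys * ys) / (2.0 * sigma * sigma))

    x, y = center
    patch = heatmap[y - radius:y + radius + 1, x - radius:x + radius + 1]
    # Take the element-wise max so overlapping objects of the same class
    # do not erase each other's peaks.
    np.maximum(patch, k * gaussian, out=patch)
    return heatmap
```

The peak value at the center is exactly 1, which is what the positive mask in the loss (hm == 1) relies on.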

 

Experiments:

The backbone is UNet++ (the original paper uses Hourglass). Below are the results after training on the VOC2007 test set (ten-odd hours of training...):

TensorBoard visualization (left: ground truth; right: predictions, no NMS):

[Figure 3: TensorBoard comparison of ground truth and predictions]
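The reason NMS can be skipped: CenterNet extracts detections by keeping only heatmap cells that equal the maximum of their 3x3 neighbourhood (the official code implements this with a 3x3 max-pool). A NumPy sketch of that idea:

```python
import numpy as np

def extract_peaks(hm, threshold=0.3):
    """Return (y, x, score) for heatmap cells that are a 3x3 local maximum."""
    h, w = hm.shape
    padded = np.pad(hm, 1, mode='constant', constant_values=0.0)
    # 3x3 neighbourhood max at every cell (NumPy stand-in for a 3x3 max-pool).
    local_max = np.max(
        [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)],
        axis=0)
    # A cell is a detection if it equals its neighbourhood max and clears the threshold.
    ys, xs = np.where((hm == local_max) & (hm >= threshold))
    return [(y, x, hm[y, x]) for y, x in zip(ys, xs)]
```

Each surviving peak then reads its wh and reg values at the same cell, so no box-overlap suppression is needed afterwards.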

(The training script is written too crudely to post here...)

 
