昨天看下Mask-rcnn的keras代码,Github上start最多的那个。由于代码量比较多,所以需要梳理下整个流程。今天用visio简单绘制下整个数据流程图,方便理解整个算法。里面的知识点还是比较多的,所以搞清楚一个算法的细节,需要我们认真阅读下源码,并思考为什么这样做。只要能够掌握好细节,我们才可以对算法部分进行改进。
注:黑色是第一阶段,也就是RPN阶段。红色是第二阶段,也就是使用RPN的输出进行分类,框回归以及分割。金字塔特征就是p2,p3,p4,p5,p6。
输入张量:
input_image, 输入图像 shape(None,h,w,n) 其中h、w必须是2^6=64的倍数
input_image_meta= np.array(
[image_id] + # size=1
list(original_image_shape) + # size=3
list(image_shape) + # size=3
list(window) + # size=4 (y1, x1, y2, x2) in image cooredinates
[scale] + # size=1
list(active_class_ids) # size=num_classes
)
input_rpn_match,RPN网络的输入,取值-1,0,1 shape(None,None,1)
input_rpn_bbox,RPN网络的输入,输入框 shape(None,None,4)
input_gt_class_ids,检测网络的输入,真实类别 shape(None,None)
input_gt_boxes,检测网络的输入,真实框 shape(None,None,4) 需要进行归一化
input_gt_masks,检测网络的输入,真实掩码 shape(None,mask_h,mask_w,MAX_INSTANCE)
输出张量:
rpn_class_logits,经过RPN网络的输出,shape(None,all_h*all_w*ratio_l,2) all_h,all_w是p2p3p4p5p6的特征图高和宽之和
rpn_class, 经过RPN玩网络的输出 ,shape(None,all_h*all_w*ratio_l,2) ratio_l代表 每个特征图中RPN_ANCHOR_RATIOS的长度
rpn_bbox, 经过RPN玩网络的输出 ,shape(None,all_h*all_w*ratio_l,4) 4代表dy,dx,log(dh),log(dw)
mrcnn_class_logits,相继经过NMS的ProposalLayer和Align_pooling后送入检测网络的分类liner shape(None,TRAIN_ROIS_PER_IMAGE,NUM_CLASS)
mrcnn_class,相继经过NMS的ProposalLayer和Align_pooling后送入检测网络的分类softmax shape(None,TRAIN_ROIS_PER_IMAGE,NUM_CLASS)
mrcnn_bbox,相继经过NMS的ProposalLayer和Align_pooling后送入检测网络的分类坐标dy,dx,log(dh),log(dw) shape(None,TRAIN_ROIS_PER_IMAGE,4)
mrcnn_mask,相继经过NMS的ProposalLayer和Align_pooling后送入掩码网络的掩码信息 shape(None,TRAIN_ROIS_PER_IMAGE,mask_h,mask_w,NUM_CLASS)
rpn_rois, 经过RPN网络后进行NMS的ProposalLayer层,shape(None,proposals_roi_num,(y1,x1,y2,x2))
output_rois, 就是中间变量rois shape(None,TRAIN_ROIS_PER_IMAGE,4)
rpn_class_loss, 由input_rpn_match, rpn_class_logits计算出来的损失
rpn_bbox_loss, 由input_rpn_bbox, input_rpn_match, rpn_bbox计算出来的损失
class_loss, 由target_class_ids, mrcnn_class_logits, active_class_ids计算出来 target_class_ids由input_image_meta得来
bbox_loss, 由target_bbox, target_class_ids, mrcnn_bbox计算出来
mask_loss, 由target_mask, target_class_ids, mrcnn_mask计算出来
中间张量:经过DetectionTargetLayer层的结果
rois, shape(None,TRAIN_ROIS_PER_IMAGE,4)
target_class_ids, shape(None,TRAIN_ROIS_PER_IMAGE)
target_bbox, shape(None,TRAIN_ROIS_PER_IMAGE,4) deta
target_mask shape(None,TRAIN_ROIS_PER_IMAGE,mask_h,mask_w)