tensorflow-FPN 代码详解

参考文章:https://www.jianshu.com/p/324af87a11a6

参考代码:https://github.com/Kongsea/FPN_TensorFlow

本文对源码的每个文件进行详细的解读!

 

tensorflow-FPN 代码详解_第1张图片

  • configs:下面是一些模型配置的超参数,这里有vgg,inception等。
  • data: 使用来做数据的工厂文件,这里的文件与数据生成有关。
  • help_utils:有两个文件,help_utils.py是show图片的一个重要文件。
  • scripts: 脚本文件,在Ubuntu下直接执行的.sh文件,调用tools文件进行train,test,eval,inferen。
  • tools:目标检测的几个阶段的主函数文件,训练,测试,评估,推断。接下来的讲解路线就是从这里开始。
  • gen_classes.py: 作者提到这是用来生成label的文件,可以得到txt,txt是制作标签时的第一步嘛。
  • cnvert_txt2xml.py: 从名字就可以看出来是把txt文件转化为xml的标注格式。

1. configs 文件

主要是以下四个网络的模型基础配置,包括numclass,batch_size,max_step等

tensorflow-FPN 代码详解_第2张图片

这里以resnet-50 为例,看看其基本的模型配置参数。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

tf.app.flags.DEFINE_string(
    'dataset_tfrecord',
    '../data/tfrecords',
    'tfrecord of fruits dataset'
)
tf.app.flags.DEFINE_integer(
    'shortside_size',
    600,
    'the value of new height and new width, new_height = new_width'
)

###########################
#  data batch
##########################
tf.app.flags.DEFINE_integer(
    'num_classes',
    20,
    'num of classes'
)
tf.app.flags.DEFINE_integer(
    'batch_size',
    1, #64
    'num of imgs in a batch'
)

###############################
# optimizer-- MomentumOptimizer
###############################
tf.app.flags.DEFINE_float(
    'momentum',
    0.9,
    'accumulation = momentum * accumulation + gradient'
)

############################
#  train
########################
tf.app.flags.DEFINE_integer(
    'max_steps',
    900000,
    'max iterate steps'
)

tf.app.flags.DEFINE_string(
    'pretrained_model_path',
    '../data/pretrained_weights/resnet_50.ckpt',
    'the path of pretrained weights'
)
tf.app.flags.DEFINE_float(
    'weight_decay',
    0.0001,
    'weight_decay in regulation'
)
################################
# summary and save_weights_checkpoint
##################################
tf.app.flags.DEFINE_string(
    'summary_path',
    '../output/resnet_summary',
    'the path of summary write to '
)
tf.app.flags.DEFINE_string(
    'trained_checkpoint',
    '../output/resnet_trained_weights',
    'the path to save trained_weights'
)
FLAGS = tf.app.flags.FLAGS

2. data文件

主要是io文件夹下面的几个py文件,对数据进行处理。

image_preprocess.py  主要是对图片进行resize 和 left-right flip 处理。

convert_data_to_tfrecord.py 主要是将数据转换成tfrecord

read_tfrecord.py 就是读tfrecord

tensorflow-FPN 代码详解_第3张图片

3. help_utils

主要是做一些显示的工作。

help_utils.py 在图像上画框,并且展示框出来之后的图像,并且对tensor 进行输出打印。

tools.py 显示代码运行时的进度条。

tensorflow-FPN 代码详解_第4张图片

4. libs

主要包括网络模型,loss,iou,box,rpn,label 等的相关代码具体实现。

tensorflow-FPN 代码详解_第5张图片

5.tools

包括模型的训练,测试,载入模型等多个功能,本次详细的看代码,也是从这里开始看。

tensorflow-FPN 代码详解_第6张图片

 

网络的主要工作流程:

  • 1. 参数传入,数据读取
  • 2. 构建基础网络,vgg,inception,提供每层的feature map。FPN不仅仅使用最后一层,中间的几层也会使用的。
  • 3. 构建RPN网络,对每层的feature map进行生成box和背景二分类
  • 4. 构建Faster RCNN网络
  • 5. 进行计算。
  • 6. 使用help_utils进行结果展示

test.py

1.获取基础网络:

 _, share_net = get_network_byname(net_name=cfgs.NET_NAME,
                                      inputs=img_batch,
                                      num_classes=None,
                                      is_training=True,
                                      output_stride=None,
                                      global_pool=False,
                                      spatial_squeeze=False)

根据网络的名称,获取基础的网络结构,一共有两个 resnet_v1_50 和 resnet_v1_101 ,直接在 libs.neywork 里面的 network_factory.py 可以看到这个函数。

def get_network_byname(net_name,
                       inputs,
                       num_classes=None,
                       is_training=True,
                       global_pool=True,
                       output_stride=None,
                       spatial_squeeze=True):
  if net_name == 'resnet_v1_50':
    FLAGS = get_flags_byname(net_name)
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=FLAGS.weight_decay)):
      logits, end_points = resnet_v1.resnet_v1_50(inputs=inputs,
                                                  num_classes=num_classes,
                                                  is_training=is_training,
                                                  global_pool=global_pool,
                                                  output_stride=output_stride,
                                                  spatial_squeeze=spatial_squeeze
                                                  )

    return logits, end_points
  if net_name == 'resnet_v1_101':
    FLAGS = get_flags_byname(net_name)
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=FLAGS.weight_decay)):
      logits, end_points = resnet_v1.resnet_v1_101(inputs=inputs,
                                                   num_classes=num_classes,
                                                   is_training=is_training,
                                                   global_pool=global_pool,
                                                   output_stride=output_stride,
                                                   spatial_squeeze=spatial_squeeze
                                                   )
    return logits, end_points

2.构建rpn 网络

 rpn = build_rpn.RPN(net_name=cfgs.NET_NAME,
                        inputs=img_batch,
                        gtboxes_and_label=None,
                        is_training=False,
                        share_head=cfgs.SHARE_HEAD,
                        share_net=share_net,
                        stride=cfgs.STRIDE,
                        anchor_ratios=cfgs.ANCHOR_RATIOS,
                        anchor_scales=cfgs.ANCHOR_SCALES,
                        scale_factors=cfgs.SCALE_FACTORS,
                        base_anchor_size_list=cfgs.BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
                        level=cfgs.LEVEL,
                        top_k_nms=cfgs.RPN_TOP_K_NMS,
                        rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
                        max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
                        rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
                        rpn_iou_negative_threshold=cfgs.RPN_IOU_NEGATIVE_THRESHOLD,
                        rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
                        rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
                        remove_outside_anchors=False,  # whether remove anchors outside
                        rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME])

3. rpn 的proposal 的预测

 # rpn predict proposals
    rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals()  # rpn_score shape: [300, ]

4. 构建 fast-rcnn 网络

fast_rcnn = build_fast_rcnn.FastRCNN(img_batch=img_batch,
                                         feature_pyramid=rpn.feature_pyramid,
                                         rpn_proposals_boxes=rpn_proposals_boxes,
                                         rpn_proposals_scores=rpn_proposals_scores,
                                         img_shape=tf.shape(img_batch),
                                         roi_size=cfgs.ROI_SIZE,
                                         scale_factors=cfgs.SCALE_FACTORS,
                                         roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
                                         gtboxes_and_label=None,
                                         fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
                                         fast_rcnn_maximum_boxes_per_img=100,
                                         fast_rcnn_nms_max_boxes_per_class=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
                                         show_detections_score_threshold=cfgs.FINAL_SCORE_THRESHOLD,  # show detections which score >= 0.6
                                         num_classes=cfgs.CLASS_NUM,
                                         fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
                                         fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
                                         fast_rcnn_positives_iou_threshold=cfgs.FAST_RCNN_IOU_POSITIVE_THRESHOLD,
                                         use_dropout=False,
                                         weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
                                         is_training=False,
                                         level=cfgs.LEVEL)

5. fast-rcnn 进行预测

  fast_rcnn_decode_boxes, fast_rcnn_score, num_of_objects, detection_category = \
        fast_rcnn.fast_rcnn_predict()

6. 参数初始化,加载模型

 # train
    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = restore_model.get_restorer(checkpoint_path=args.weights)

 

你可能感兴趣的:(深度学习)