参考文章:https://www.jianshu.com/p/324af87a11a6
参考代码:https://github.com/Kongsea/FPN_TensorFlow
主要是以下四个网络的模型基础配置,包括 num_classes、batch_size、max_steps 等。
这里以resnet-50 为例,看看其基本的模型配置参数。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# Command-line / configuration flags for the ResNet-50 setup.
# A short private alias is used for readability; it is deleted at the end of
# this block so only FLAGS remains as the public result.
_flags = tf.app.flags

##########################
# dataset / input
##########################
_flags.DEFINE_string('dataset_tfrecord', '../data/tfrecords',
                     'tfrecord of fruits dataset')
_flags.DEFINE_integer('shortside_size', 600,
                      'the value of new height and new width, new_height = new_width')

##########################
# data batch
##########################
_flags.DEFINE_integer('num_classes', 20,
                      'num of classes')
# batch_size was 64 at some point (see the original "#64" note); 1 is used now.
_flags.DEFINE_integer('batch_size', 1,
                      'num of imgs in a batch')

##########################
# optimizer -- MomentumOptimizer
##########################
_flags.DEFINE_float('momentum', 0.9,
                    'accumulation = momentum * accumulation + gradient')

##########################
# train
##########################
_flags.DEFINE_integer('max_steps', 900000,
                      'max iterate steps')
_flags.DEFINE_string('pretrained_model_path', '../data/pretrained_weights/resnet_50.ckpt',
                     'the path of pretrained weights')
_flags.DEFINE_float('weight_decay', 0.0001,
                    'weight_decay in regulation')

##########################
# summary and checkpoint output
##########################
_flags.DEFINE_string('summary_path', '../output/resnet_summary',
                     'the path of summary write to ')
_flags.DEFINE_string('trained_checkpoint', '../output/resnet_trained_weights',
                     'the path to save trained_weights')

# Parsed flag values are read through this object elsewhere.
FLAGS = _flags.FLAGS
del _flags
主要是io文件夹下面的几个py文件,对数据进行处理。
image_preprocess.py 主要是对图片进行resize 和 left-right flip 处理。
convert_data_to_tfrecord.py 主要是将数据转换成tfrecord
read_tfrecord.py 就是读tfrecord
主要是做一些显示的工作。
help_utils.py 在图像上绘制检测框并显示绘制后的图像,同时可打印 tensor 的值。
tools.py 显示代码运行时的进度条。
主要包括网络模型,loss,iou,box,rpn,label 等的相关代码具体实现。
包括模型的训练、测试、载入模型等多个功能;本次详细阅读代码也从这里开始。
1.获取基础网络:
# Build the shared backbone. get_network_byname returns (logits, end_points);
# the logits are discarded and only the end-points are kept as share_net.
# NOTE(review): is_training=True here, while the RPN / Fast R-CNN below use
# is_training=False -- confirm this is intended for the inference path.
_, share_net = get_network_byname(
    net_name=cfgs.NET_NAME,
    inputs=img_batch,
    num_classes=None,
    is_training=True,
    output_stride=None,
    global_pool=False,
    spatial_squeeze=False,
)
根据网络名称获取基础网络结构,目前支持两种:resnet_v1_50 和 resnet_v1_101。该函数位于 libs/networks 目录下的 network_factory.py 中。
def get_network_byname(net_name,
                       inputs,
                       num_classes=None,
                       is_training=True,
                       global_pool=True,
                       output_stride=None,
                       spatial_squeeze=True):
    """Build a TF-Slim ResNet-v1 backbone selected by name.

    Args:
        net_name: backbone name, either 'resnet_v1_50' or 'resnet_v1_101'.
        inputs: input tensor, passed straight to the backbone constructor.
        num_classes: forwarded to the backbone constructor.
        is_training: forwarded to the backbone constructor.
        global_pool: forwarded to the backbone constructor.
        output_stride: forwarded to the backbone constructor.
        spatial_squeeze: forwarded to the backbone constructor.

    Returns:
        The ``(logits, end_points)`` pair returned by the backbone constructor.

    Raises:
        ValueError: if ``net_name`` is not a supported backbone name.
    """
    supported = ('resnet_v1_50', 'resnet_v1_101')
    if net_name not in supported:
        # Fail loudly here instead of implicitly returning None, which would
        # only surface later as a confusing unpacking error at the call site.
        raise ValueError('unsupported net_name %r; expected one of: %s'
                         % (net_name, ', '.join(supported)))

    # Per-network config; only weight_decay is used here.
    flags = get_flags_byname(net_name)
    # Both supported backbones are wired identically; the constructor's
    # attribute name in resnet_v1 is exactly net_name.
    backbone_fn = getattr(resnet_v1, net_name)
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=flags.weight_decay)):
        logits, end_points = backbone_fn(inputs=inputs,
                                         num_classes=num_classes,
                                         is_training=is_training,
                                         global_pool=global_pool,
                                         output_stride=output_stride,
                                         spatial_squeeze=spatial_squeeze)
    return logits, end_points
2. 构建 RPN 网络
# Region Proposal Network on top of the shared backbone (inference mode:
# no ground-truth boxes, is_training=False).
rpn = build_rpn.RPN(
    net_name=cfgs.NET_NAME,
    inputs=img_batch,
    gtboxes_and_label=None,
    is_training=False,
    share_head=cfgs.SHARE_HEAD,
    share_net=share_net,
    # anchor generation
    stride=cfgs.STRIDE,
    anchor_ratios=cfgs.ANCHOR_RATIOS,
    anchor_scales=cfgs.ANCHOR_SCALES,
    scale_factors=cfgs.SCALE_FACTORS,
    base_anchor_size_list=cfgs.BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
    level=cfgs.LEVEL,
    # proposal selection / NMS
    top_k_nms=cfgs.RPN_TOP_K_NMS,
    rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
    max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
    # anchor labelling and sampling (presumably only used when training)
    rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
    rpn_iou_negative_threshold=cfgs.RPN_IOU_NEGATIVE_THRESHOLD,
    rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
    rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
    remove_outside_anchors=False,  # whether to remove anchors outside the image
    rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
)
3. RPN 预测候选框(proposals)
# Ask the RPN for its proposal boxes and their objectness scores.
# The original note gave the score shape as [300, ]; presumably this is bounded
# by cfgs.MAX_PROPOSAL_NUM -- confirm.
(rpn_proposals_boxes,
 rpn_proposals_scores) = rpn.rpn_proposals()
4. 构建 Fast R-CNN 网络
# Second-stage detector fed by the RPN's feature pyramid and proposals
# (inference mode: no ground truth, dropout off, is_training=False).
fast_rcnn = build_fast_rcnn.FastRCNN(
    img_batch=img_batch,
    feature_pyramid=rpn.feature_pyramid,
    rpn_proposals_boxes=rpn_proposals_boxes,
    rpn_proposals_scores=rpn_proposals_scores,
    img_shape=tf.shape(img_batch),
    # RoI pooling
    roi_size=cfgs.ROI_SIZE,
    scale_factors=cfgs.SCALE_FACTORS,
    roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
    gtboxes_and_label=None,
    # post-processing
    fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
    fast_rcnn_maximum_boxes_per_img=100,
    fast_rcnn_nms_max_boxes_per_class=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
    # score cut-off for displayed detections (the original comment said 0.6;
    # verify against cfgs.FINAL_SCORE_THRESHOLD)
    show_detections_score_threshold=cfgs.FINAL_SCORE_THRESHOLD,
    num_classes=cfgs.CLASS_NUM,
    # proposal sampling (presumably only used when training)
    fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
    fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
    fast_rcnn_positives_iou_threshold=cfgs.FAST_RCNN_IOU_POSITIVE_THRESHOLD,
    use_dropout=False,
    weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
    is_training=False,
    level=cfgs.LEVEL,
)
5. Fast R-CNN 进行预测
# Final detections from the second stage. Going by the names: decoded boxes,
# their scores, the number of objects and each detection's category.
(fast_rcnn_decode_boxes,
 fast_rcnn_score,
 num_of_objects,
 detection_category) = fast_rcnn.fast_rcnn_predict()
6. 参数初始化,加载模型
# Variable initialisation and checkpoint restoring. The network above is built
# for inference, so the old "# train" label did not describe this step.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
# args is presumably the parsed command-line namespace -- confirm in the caller.
restorer, restore_ckpt = restore_model.get_restorer(
    checkpoint_path=args.weights)