Faster R-CNN源码阅读之九:Faster R-CNN/tools/train_net.py

  1. Faster R-CNN源码阅读之零:写在前面
  2. Faster R-CNN源码阅读之一:Faster R-CNN/lib/networks/network.py
  3. Faster R-CNN源码阅读之二:Faster R-CNN/lib/networks/factory.py
  4. Faster R-CNN源码阅读之三:Faster R-CNN/lib/networks/VGGnet_test.py
  5. Faster R-CNN源码阅读之四:Faster R-CNN/lib/rpn_msr/generate_anchors.py
  6. Faster R-CNN源码阅读之五:Faster R-CNN/lib/rpn_msr/proposal_layer_tf.py
  7. Faster R-CNN源码阅读之六:Faster R-CNN/lib/fast_rcnn/bbox_transform.py
  8. Faster R-CNN源码阅读之七:Faster R-CNN/lib/rpn_msr/anchor_target_layer_tf.py
  9. Faster R-CNN源码阅读之八:Faster R-CNN/lib/rpn_msr/proposal_target_layer_tf.py
  10. Faster R-CNN源码阅读之九:Faster R-CNN/tools/train_net.py
  11. Faster R-CNN源码阅读之十:Faster R-CNN/lib/fast_rcnn/train.py
  12. Faster R-CNN源码阅读之十一:Faster R-CNN预测demo代码补完
  13. Faster R-CNN源码阅读之十二:写在最后

一、介绍
   本demo由Faster R-CNN官方提供,我只是在官方的代码上增加了注释,一方面方便我自己学习,另一方面贴出来和大家一起交流。
   该文件中的函数的训练Faster RCNN网络的主入口,并通过命令行等传入以下必要的配置信息,然后开始训练网络。

二、代码以及注释

#!/usr/bin/env python
# coding=utf-8

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
from networks.factory import get_network
import argparse
import pprint
import numpy as np
import sys
import pdb

def parse_args():
    """
    Parse input arguments
    配置传入的参数变量
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    # --device代表是使用gpu还是cpu,默认为cpu
    parser.add_argument('--device', dest='device', help='device to use',
                        default='cpu', type=str)

    # --device_id代表设备(cpu或者gpu)编号
    parser.add_argument('--device_id', dest='device_id', help='device id to use',
                        default=0, type=int)
    # --solver代表模型的配置文件
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    # --iters代表最大的循环迭代次数,默认为70000次
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    # --weights代表预训练的权重文件路径
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    # --cfg代表可选的配置文件
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    # --imdb代表训练的数据集
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='kitti_train', type=str)
    # --rand代表是否使用不同的随机数种子生成随机数
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    # --network代表网络名称,一般具有固定的形式。常以'_train'结尾。
    parser.add_argument('--network', dest='network_name',
                        help='name of the network',
                        default=None, type=str)
    # --set的功能见下
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        # Load a config file and merge it into the default options.
        # 从config文件中加载配置信息,并添加到默认的选项中
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        # Set config keys via list (e.g., from command line).
        # 通过list设置配置信息
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    # 设置numpy的随机数种子
    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
    # 根据使用的图片数据集名称获取数据集
    imdb = get_imdb(args.imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    # 将训练数据变成minibatch的形式
    roidb = get_training_roidb(imdb)

    # 设置网络权重文件的保存目录
    output_dir = get_output_dir(imdb, None)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    # 设置device name
    device_name = '/{}:{:d}'.format(args.device,args.device_id)
    print device_name

    # 根据network name建立网络结构
    network = get_network(args.network_name)
    print 'Use network `{:s}` in training'.format(args.network_name)

    # 训练网络
    train_net(network, imdb, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)

你可能感兴趣的:(Faster,RCNN,源码阅读,深度学习,Tensorflow,Faster,RCNN)