Fast R-CNN Walkthrough: A Guide to the Training Code

Reposted from:

Fast R-CNN Walkthrough: A Guide to the Training Code - LinJM - Machine Vision - Blog Channel - CSDN.NET
http://blog.csdn.net/linj_m/article/details/48930179#0-tsina-1-35514-397232819ff9a47a7b7e80a40613cfe1


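This walkthrough of Fast R-CNN is split into two parts. The first covers training; it is the more important part and the one we explain in the most detail. The second covers testing, which matters for practical applications and is therefore also worth understanding. This post starts with training.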

Training Workflow

In the official documentation, training is launched with the following script:

```shell
./tools/train_net.py --gpu 0 --solver models/VGG16/solver.prototxt \
    --weights data/imagenet_models/VGG16.v2.caffemodel
```

From this script we can see that the training entry point is train_net.py, located in the fast-rcnn/tools/ folder. Let's look at this file first.

```python
if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)

    # Optional overrides: a config file first, then individual keys from --set
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
        caffe.set_random_seed(cfg.RNG_SEED)

    # set up caffe
    caffe.set_mode_gpu()
    if args.gpu_id is not None:
        caffe.set_device(args.gpu_id)

    # Build the image database, then the RoI database used for training
    imdb = get_imdb(args.imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    print('Output will be saved to `{:s}`'.format(output_dir))

    train_net(args.solver, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)
```

From the code above, we can see that train_net.py does three main things:

(1) First, it processes the launch script's input arguments with parse_args():

```python
import argparse
import sys


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]', default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt', default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train', default=40000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights', default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file', default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on', default='voc_2007_trainval', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)', action='store_true')
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None, nargs=argparse.REMAINDER)

    # With no arguments at all, show the usage text and exit
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args
```

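From this function we can see that the optional arguments of the training script include:

  • --gpu: selects the GPU device used for training. My machine has a single GPU, which is used by default with gpu_id 0;
  • --solver: specifies the solver prototxt, which configures the optimization and in turn points to the file that defines the network structure (train.prototxt);
  • --weights: specifies the initial weights for fine-tuning. My GPU is not very high-end, so I can only fine-tune CaffeNet;
  • --imdb: specifies the training dataset. The default is voc_2007_trainval, so you must set this argument if you want to train on your own data.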
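To make the flag handling concrete, here is a minimal Python 3 sketch that rebuilds a subset of the parser above and parses a sample command line. The helper `apply_overrides` and the nested `config` dictionary are hypothetical stand-ins for `cfg_from_list` and `cfg`, whose real implementations are not shown in this post; the sketch only assumes that `--set` receives alternating key and value tokens.

```python
import argparse
from ast import literal_eval


def build_parser():
    """Rebuild a subset of the training flags shown above."""
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', default=0, type=int)
    parser.add_argument('--iters', dest='max_iters', default=40000, type=int)
    parser.add_argument('--imdb', dest='imdb_name',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--rand', dest='randomize', action='store_true')
    # REMAINDER collects every token that follows --set into one list.
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER)
    return parser


def apply_overrides(config, pairs):
    """Hypothetical stand-in for cfg_from_list: apply KEY VALUE pairs.

    Keys use dotted paths such as 'TRAIN.SNAPSHOT_ITERS'.
    """
    assert len(pairs) % 2 == 0, 'expected alternating keys and values'
    for key, raw in zip(pairs[0::2], pairs[1::2]):
        *parents, leaf = key.split('.')
        node = config
        for part in parents:
            node = node[part]
        if leaf not in node:
            raise KeyError('Unknown config key: {}'.format(key))
        try:
            value = literal_eval(raw)   # '5000' becomes the int 5000
        except (ValueError, SyntaxError):
            value = raw                 # keep plain strings as they are
        node[leaf] = value
    return config


args = build_parser().parse_args(
    ['--gpu', '1', '--imdb', 'inria_train',
     '--set', 'TRAIN.SNAPSHOT_ITERS', '5000'])
config = {'RNG_SEED': 3, 'TRAIN': {'SNAPSHOT_ITERS': 10000}}
apply_overrides(config, args.set_cfgs)
```

Because `--set` uses `argparse.REMAINDER`, it has to come last on the command line: every token after it is treated as part of the override list rather than as another flag.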

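(2) Next, it prepares the training samples according to the input arguments (the dataset named after --imdb). This step involves two functions: imdb = get_imdb(args.imdb_name) and roidb = get_training_roidb(imdb). We spend a good deal of time on these two functions in the next part, so we skip them for now.

(3) Finally, it calls the training function train_net(args.solver, roidb, output_dir, pretrained_model=args.pretrained_model, max_iters=args.max_iters).

The train_net() function is imported from train.py in the fast-rcnn/lib/fast_rcnn folder, so let's look at train.py next.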


This file mainly consists of one class, SolverWrapper, and two functions, get_training_roidb() and train_net().
First, let's look at train_net():

```python
def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""
    # The wrapper owns the Caffe solver and the snapshot logic
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print('Solving...')
    sw.train_model(max_iters)
    print('done solving')
```

As you can see, this function does its main work through the SolverWrapper class, so let's follow it into the class constructor:

```python
    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper."""
        self.output_dir = output_dir

        print('Computing bounding-box regression targets...')
        self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
        print('done')

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print('Loading pretrained model '
                  'weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the solver prototxt so its fields (e.g. display) can be read later
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # Hand the roidb to the network's first layer (the data layer)
        self.solver.net.layers[0].set_roidb(roidb)
```
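The constructor stores `bbox_means` and `bbox_stds`, which suggests that the bounding-box regression targets are normalized before training. The sketch below illustrates that idea in plain Python under that assumption. `normalize_targets` and `denormalize` are hypothetical helpers, not the code behind `rdl_roidb.add_bbox_regression_targets`, which is not shown in this post.

```python
from statistics import mean, pstdev


def normalize_targets(targets):
    """Normalize each of the 4 target columns to zero mean and unit std.

    `targets` is a list of (dx, dy, dw, dh) tuples. Returns the normalized
    rows together with the per-column means and standard deviations.
    """
    columns = list(zip(*targets))
    means = [mean(col) for col in columns]
    # Fall back to 1.0 when a column has zero spread, to avoid dividing by 0.
    stds = [pstdev(col) or 1.0 for col in columns]
    normalized = [
        tuple((v - m) / s for v, m, s in zip(row, means, stds))
        for row in targets
    ]
    return normalized, means, stds


def denormalize(row, means, stds):
    """Map a normalized row back to the original target scale."""
    return tuple(v * s + m for v, m, s in zip(row, means, stds))


# Made-up target values, for illustration only.
targets = [(0.10, -0.20, 0.30, 0.05),
           (0.30, 0.00, -0.10, 0.15),
           (-0.10, 0.20, 0.10, -0.05)]
normalized, means, stds = normalize_targets(targets)
restored = [denormalize(row, means, stds) for row in normalized]
```

Keeping the means and standard deviations on the wrapper is presumably what allows predictions made in the normalized space to be mapped back to the original scale later.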

Once initialization is complete, train_model is called to train the network. Here is its main body:

```python
def train_model(self, max_iters):
    """Network training loop."""
    last_snapshot_iter = -1
    timer = Timer()
    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()
        # Report the average step time every 10 * display iterations
        if self.solver.iter % (10 * self.solver_param.display) == 0:
            print('speed: {:.3f}s / iter'.format(timer.average_time))

        # Save a snapshot every cfg.TRAIN.SNAPSHOT_ITERS iterations
        if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            self.snapshot()

    # Save the final state unless it was just saved inside the loop
    if last_snapshot_iter != self.solver.iter:
        self.snapshot()
```

At this point, the network can start training.
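The snapshot schedule in train_model can be checked without Caffe. The sketch below replays the same loop structure against a hypothetical `FakeSolver` that only counts iterations. `run_loop` and its `snapshot_iters` parameter are illustrative names standing in for `train_model` and `cfg.TRAIN.SNAPSHOT_ITERS`.

```python
class FakeSolver:
    """Hypothetical stand-in for caffe.SGDSolver that only counts steps."""

    def __init__(self):
        self.iter = 0

    def step(self, n):
        self.iter += n


def run_loop(solver, max_iters, snapshot_iters):
    """Replay the loop structure of train_model and record snapshot points."""
    snapshots = []
    last_snapshot_iter = -1
    while solver.iter < max_iters:
        solver.step(1)  # one SGD update in the real solver
        if solver.iter % snapshot_iters == 0:
            last_snapshot_iter = solver.iter
            snapshots.append(solver.iter)
    # Final snapshot, taken only if the last iteration was not saved already.
    if last_snapshot_iter != solver.iter:
        snapshots.append(solver.iter)
    return snapshots


uneven = run_loop(FakeSolver(), max_iters=25, snapshot_iters=10)
even = run_loop(FakeSolver(), max_iters=20, snapshot_iters=10)
```

With `max_iters=25` the loop records snapshots at iterations 10, 20, and 25. With `max_iters=20` the trailing check adds nothing, because iteration 20 was already saved inside the loop.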

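Training Data Processing

However, we have not yet reached the main event of Fast R-CNN: how to prepare the training data.

In the training workflow described above, the relevant call is: imdb = get_imdb(args.imdb_name)

This function is imported from factory.py in the lib/datasets/ folder. Let's take a look at it: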

```python
def get_imdb(name):
    """Get an imdb (image database) by name."""
    # `in` works on both Python 2 and 3, unlike dict.has_key()
    if name not in __sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
```

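This function is simple: it looks up the dataset name in a dictionary and calls the stored factory to build the training dataset.
So how is this dictionary built? See below: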

```python
inria_devkit_path = '/home/jeremy/jWork/frcn/fast-rcnn/data/INRIA/'
for split in ['train', 'test']:
    name = '{}_{}'.format('inria', split)
    __sets[name] = (lambda split=split: datasets.inria(split, inria_devkit_path))
```
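The `split=split` default argument in the lambda is easy to overlook, so here is a small self-contained sketch of the same registry pattern. `ToyDataset`, `registry`, and `get_dataset` are hypothetical names that mirror `datasets.inria`, `__sets`, and `get_imdb`, and the devkit path is a placeholder.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset class such as datasets.inria."""

    def __init__(self, split, devkit_path):
        self.split = split
        self.devkit_path = devkit_path
        self.name = 'toy_{}'.format(split)


registry = {}
for split in ['train', 'test']:
    name = 'toy_{}'.format(split)
    # The default argument freezes the current value of `split`.
    registry[name] = (lambda split=split: ToyDataset(split, '/path/to/toy'))


def get_dataset(name):
    """Look up a factory by name and call it to build the dataset."""
    if name not in registry:
        raise KeyError('Unknown dataset: {}'.format(name))
    return registry[name]()


def build_late_bound():
    """Same loop without the default argument, to show the pitfall."""
    table = {}
    for split in ['train', 'test']:
        table[split] = lambda: split  # looks up `split` only when called
    return table


train_set = get_dataset('toy_train')
late = build_late_bound()
```

Here `train_set.split` is `'train'`, as intended, while `late['train']()` returns `'test'`: without the default argument, every lambda sees the value the loop variable had when the loop finished. Storing factories rather than instances also means a dataset is only constructed when it is actually requested.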

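In essence, the dataset is brought in through inria.py under the lib/datasets/ folder.
So we now need to step into inria.py (we have to write this file ourselves; pascal_voc.py can serve as a reference).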
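The registration loop above relies on a small Python detail: `lambda split=split` captures the current value of `split` as a default argument. Here is a minimal sketch of why that matters. `FakeInria`, `sets`, `late_bound` and the placeholder path are made up for illustration (they stand in for `datasets.inria` and `__sets`), and the sketch is written for Python 3 while the excerpts in this post are Python 2.

```python
# Hypothetical stand-in for datasets.inria; the real class lives in lib/datasets/inria.py.
class FakeInria:
    def __init__(self, image_set, devkit_path):
        self.name = 'inria_' + image_set
        self.devkit_path = devkit_path


devkit_path = '/path/to/INRIA'  # placeholder path, not a real location
sets = {}        # plays the role of __sets
late_bound = {}  # same registry, but without the default-argument trick

for split in ['train', 'test']:
    name = '{}_{}'.format('inria', split)
    # Default argument: the value of split is frozen for this iteration.
    sets[name] = (lambda split=split: FakeInria(split, devkit_path))
    # No default argument: split is looked up only when the lambda is called.
    late_bound[name] = (lambda: FakeInria(split, devkit_path))


def get_imdb(name):
    """Look up a factory by name and build the dataset object."""
    if name not in sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return sets[name]()


print(get_imdb('inria_train').name)        # inria_train
print(late_bound['inria_train']().name)    # inria_test (loop variable ended on 'test')
```

Without the default argument, every registered factory would build the dataset for whichever split the loop visited last.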


First, let's look at the constructor of the inria class:

<code class="language-python hljs  has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;"> <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">__init__</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self, image_set, devkit_path)</span>:</span>
        datasets.imdb.__init__(self, image_set)
        self._image_set = image_set
        self._devkit_path = devkit_path
        self._data_path = os.path.join(self._devkit_path, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'data'</span>)
        self._classes = (<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'__background__'</span>, <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># always index 0</span>
                         <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'1001'</span>)
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = [<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'.jpg'</span>, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'.png'</span>]
        self._image_index = self._load_image_set_index()
        <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># Default to roidb handler</span>
        self._roidb_handler = self.selective_search_roidb

        <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># Specific config options</span>
        self.config = {<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'cleanup'</span>  : <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">True</span>,
                       <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'use_salt'</span> : <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">True</span>,
                       <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'top_k'</span>    : <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2000</span>}

        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">assert</span> os.path.exists(self._devkit_path), \
                <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'Devkit path does not exist: {}'</span>.format(self._devkit_path)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">assert</span> os.path.exists(self._data_path), \
                <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'Path does not exist: {}'</span>.format(self._data_path)</code>

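The most important thing to watch here is that self._classes must be updated to match the classes you are actually training on; in my case there are only two (the background plus one object class).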
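To make the class list concrete, here is a small sketch of the mapping that the constructor builds. The helper `label_to_index` is hypothetical; it only mirrors the `.lower().strip()` normalization that the annotation loader shown later in this post applies before the lookup. It is written for Python 3, so `range` replaces `xrange`.

```python
# Same two-class setup as in the constructor: background plus one object class.
classes = ('__background__',  # always index 0
           '1001')
num_classes = len(classes)
class_to_ind = dict(zip(classes, range(num_classes)))


def label_to_index(raw_name):
    """Hypothetical helper: normalize an annotation label, then look it up."""
    return class_to_ind[raw_name.lower().strip()]


print(class_to_ind)                 # {'__background__': 0, '1001': 1}
print(label_to_index(' 1001 '))     # 1
```

Because labels are lowercased before the lookup, the names listed in `self._classes` should be lowercase too, otherwise the lookup raises a `KeyError`.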

Once the inria object has been constructed, roidb is called. It is inherited from the imdb class and delegates the work to _roidb_handler, which here is set to self.selective_search_roidb. Let's look at that function:

<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;"><span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">selective_search_roidb</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self)</span>:</span>
    <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"""
    Return the database of selective search regions of interest.
    Ground-truth ROIs are also included.
    This function loads/saves from/to a cache file to speed up future calls.
    """</span>
    cache_file = os.path.join(self.cache_path,
                             self.name + <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'_selective_search_roidb.pkl'</span>)

    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> os.path.exists(cache_file):
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">with</span> open(cache_file, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'rb'</span>) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">as</span> fid:
            roidb = cPickle.load(fid)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">print</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'{} ss roidb loaded from {}'</span>.format(self.name, cache_file)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> roidb

    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> self._image_set != <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'test'</span>:
        gt_roidb = self.gt_roidb()
        ss_roidb = self._load_selective_search_roidb(gt_roidb)
        roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">else</span>:
        roidb = self._load_selective_search_roidb(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">None</span>)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">print</span> len(roidb)
    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">with</span> open(cache_file, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'wb'</span>) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">as</span> fid:
        cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">print</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'wrote ss roidb to {}'</span>.format(cache_file)

    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> roidb</code>
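The hand-off from roidb to the handler can be pictured with the simplified sketch below. `MiniImdb` and `MiniInria` are made-up stand-ins, not the real imdb code; the only assumption is that the roidb is built once through `_roidb_handler` and then reused.

```python
class MiniImdb:
    """Simplified stand-in for the imdb base class (illustration only)."""

    def __init__(self, name):
        self.name = name
        self._roidb = None
        self._roidb_handler = self.default_roidb

    def default_roidb(self):
        raise NotImplementedError

    @property
    def roidb(self):
        # Build the roidb on first access through the handler, then reuse it.
        if self._roidb is None:
            self._roidb = self._roidb_handler()
        return self._roidb


class MiniInria(MiniImdb):
    def __init__(self, image_set):
        MiniImdb.__init__(self, 'inria_' + image_set)
        self.calls = 0
        # Same idea as in the constructor above: pick the handler explicitly.
        self._roidb_handler = self.selective_search_roidb

    def selective_search_roidb(self):
        self.calls += 1
        return [{'boxes': [[0, 0, 10, 10]], 'flipped': False}]


imdb = MiniInria('train')
first = imdb.roidb
second = imdb.roidb
print(imdb.calls)  # 1: the handler ran only once
```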

During training, this function first calls gt_roidb():

<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">    <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">gt_roidb</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self)</span>:</span>
        <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"""
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """</span>
        cache_file = os.path.join(self.cache_path, self.name + <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'_gt_roidb.pkl'</span>)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> os.path.exists(cache_file):
            <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">with</span> open(cache_file, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'rb'</span>) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">as</span> fid:
                roidb = cPickle.load(fid)
            <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">print</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'{} gt roidb loaded from {}'</span>.format(self.name, cache_file)
            <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> roidb

        gt_roidb = [self._load_inria_annotation(index)
                    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> index <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">in</span> self.image_index]
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">with</span> open(cache_file, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'wb'</span>) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">as</span> fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">print</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'wrote gt roidb to {}'</span>.format(cache_file)

        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> gt_roidb</code>
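Both gt_roidb and selective_search_roidb follow the same load-or-build caching pattern. Here is a minimal Python 3 sketch of it, using `pickle` in place of Python 2's `cPickle`; `load_or_build`, `build_gt` and the temporary cache directory are hypothetical names chosen for the example.

```python
import os
import pickle
import tempfile


def load_or_build(cache_file, build_fn):
    """Return the cached object if the file exists, else build and cache it."""
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fid:
            return pickle.load(fid)
    data = build_fn()
    with open(cache_file, 'wb') as fid:
        pickle.dump(data, fid, pickle.HIGHEST_PROTOCOL)
    return data


build_count = [0]


def build_gt():
    # Stands in for the expensive step of parsing every annotation file.
    build_count[0] += 1
    return [{'boxes': [[1, 2, 3, 4]], 'gt_classes': [1]}]


cache_dir = tempfile.mkdtemp()
cache_file = os.path.join(cache_dir, 'inria_train_gt_roidb.pkl')
a = load_or_build(cache_file, build_gt)   # builds and writes the cache
b = load_or_build(cache_file, build_gt)   # reads the cache, no rebuild
print(build_count[0])  # 1
```

One practical consequence of this pattern: once the .pkl file exists it is always preferred, so after changing the annotations or the class list the old cache file has to be deleted by hand.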

If cache_file exists, gt_roidb() reads the data directly from it; otherwise it calls _load_inria_annotation() to obtain the annotations. The _load_inria_annotation function is shown below:

<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;"><span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">_load_inria_annotation</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self, index)</span>:</span>
        <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"""
        Load image and bounding boxes info from txt files of INRIA Person.
        """</span>
        filename = os.path.join(self._data_path, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'Annotations'</span>, index + <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'.xml'</span>)
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">print</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'Loading: {}'</span>.format(filename)

        <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">get_data_from_tag</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(node, tag)</span>:</span>
            <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> node.getElementsByTagName(tag)[<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>].childNodes[<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>].data

        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">with</span> open(filename) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">as</span> f:
            data = minidom.parseString(f.read())

        objs = data.getElementsByTagName(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'object'</span>)
        num_objs = len(objs)

        boxes = np.zeros((num_objs, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)

        <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># Load object bounding boxes into a data frame.</span>
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> ix, obj <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">in</span> enumerate(objs):
            <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># Make pixel indexes 0-based</span>
            x1 = float(get_data_from_tag(obj, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'xmin'</span>)) - <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>
            y1 = float(get_data_from_tag(obj, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'ymin'</span>)) - <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>
            x2 = float(get_data_from_tag(obj, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'xmax'</span>)) - <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>
            y2 = float(get_data_from_tag(obj, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'ymax'</span>)) - <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>
            <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># ---------------------------------------------</span>
            <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># add these lines to avoid the assertion error</span>
            <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> x1 < <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>:
                x1 = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>
            <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> y1 < <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>:
                y1 = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>
            <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># ----------------------------------------------</span>
            cls = self._class_to_ind[
                    str(get_data_from_tag(obj, <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"name"</span>)).lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.0</span>

        overlaps = scipy.sparse.csr_matrix(overlaps)

        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> {<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'boxes'</span> : boxes,
                <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'gt_classes'</span>: gt_classes,
                <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'gt_overlaps'</span> : overlaps,
                <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'flipped'</span> : <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">False</span>}</code>
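The parsing step can be tried in isolation with the sketch below. `SAMPLE_XML` is a made-up annotation in the VOC-style layout that the loader expects, and `parse_boxes` is a hypothetical, numpy-free reduction of the loop above (Python 3, standard library only).

```python
from xml.dom import minidom

# Made-up annotation for illustration; real files live under data/Annotations/.
SAMPLE_XML = """<annotation>
  <object>
    <name>1001</name>
    <bndbox><xmin>1</xmin><ymin>0</ymin><xmax>120</xmax><ymax>240</ymax></bndbox>
  </object>
</annotation>"""


def get_data_from_tag(node, tag):
    """Text content of the first descendant element named tag."""
    return node.getElementsByTagName(tag)[0].childNodes[0].data


def parse_boxes(xml_text, class_to_ind):
    doc = minidom.parseString(xml_text)
    boxes, labels = [], []
    for obj in doc.getElementsByTagName('object'):
        box = []
        for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
            # 1-based pixel index to 0-based, clamped so it never goes negative.
            value = int(float(get_data_from_tag(obj, tag))) - 1
            box.append(max(0, value))
        boxes.append(box)
        labels.append(class_to_ind[str(get_data_from_tag(obj, 'name')).lower().strip()])
    return boxes, labels


boxes, labels = parse_boxes(SAMPLE_XML, {'__background__': 0, '1001': 1})
print(boxes)   # [[0, 0, 119, 239]]
print(labels)  # [1]
```

The clamp matters because the loader stores boxes as unsigned 16-bit integers: a coordinate of 0 in the file would become -1 after the shift, which cannot be represented in that type. This is presumably the case the author's extra `if` lines are guarding against.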

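Once the annotations have been processed, the next step is to load the proposals produced by the Selective Search (SS) stage, which is done by the following function: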

<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">    <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">_load_selective_search_roidb</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self, gt_roidb)</span>:</span>
        filename = os.path.abspath(os.path.join(self._devkit_path,
                                                self.name + <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'.mat'</span>))
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">assert</span> os.path.exists(filename), \
               <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'Selective search data not found at: {}'</span>.format(filename)
        raw_data = sio.loadmat(filename)[<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'boxes'</span>].ravel()

        box_list = []
        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> i <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">in</span> xrange(raw_data.shape[<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>]):
            <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># Note: if you already transformed the box values in the SS step, there is no need to shift the box values again here</span>
            <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#box_list.append(raw_data[i][:, (1, 0, 3, 2)] - 1)</span>
            box_list.append(raw_data[i][:, (<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>)])

        <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> self.create_roidb_from_box_list(box_list, gt_roidb)</code>

One thing to note: the box values produced by SS differ slightly from what fast-rcnn expects, so you need to swap the x and y coordinates of each box.
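The coordinate swap is just a column reordering with the index tuple (1, 0, 3, 2). Below is a small sketch on plain Python lists, assuming the SS output stores each box as (y1, x1, y2, x2). `ss_to_xyxy` and its `one_based` flag are hypothetical names: the flag covers the optional `- 1` shift that is commented out in the function above.

```python
def ss_to_xyxy(ss_boxes, one_based=True):
    """Reorder (y1, x1, y2, x2) boxes to (x1, y1, x2, y2).

    If one_based is True, also shift the values to 0-based pixel indexes.
    """
    offset = 1 if one_based else 0
    order = (1, 0, 3, 2)  # same column order as in the loader above
    return [[box[i] - offset for i in order] for box in ss_boxes]


ss_boxes = [[11, 21, 111, 221]]  # one box as (y1, x1, y2, x2)
print(ss_to_xyxy(ss_boxes))                   # [[20, 10, 220, 110]]
print(ss_to_xyxy(ss_boxes, one_based=False))  # [[21, 11, 221, 111]]
```

Whether the shift is needed depends on how the .mat file was produced, so it is worth checking a few boxes against an image before training.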

To be continued...

