Reposted from:
Fast-RCNN Analysis: A Walkthrough of the Training-Stage Code - LinJM - Machine Vision - CSDN.NET
http://blog.csdn.net/linj_m/article/details/48930179
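We split our analysis of Fast-RCNN into two parts. The first is training. It is the most important part, so we explain it in the most detail. The second is testing, which determines how the model is actually applied, so it is also essential. This post starts with training.
In the official documentation, training is launched with the following command: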
```shell
./tools/train_net.py --gpu 0 --solver models/VGG16/solver.prototxt \
    --weights data/imagenet_models/VGG16.v2.caffemodel
```
This command tells us that training starts in train_net.py, which lives in the fast-rcnn/tools/ folder. Let's look at that file first.
```python
if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
        caffe.set_random_seed(cfg.RNG_SEED)

    # set up caffe
    caffe.set_mode_gpu()
    if args.gpu_id is not None:
        caffe.set_device(args.gpu_id)

    imdb = get_imdb(args.imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    train_net(args.solver, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)
```
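The two if-checks on args.cfg_file and args.set_cfgs let you override the default configuration before training starts. --cfg names an optional config file that is handed to cfg_from_file(), and --set passes key/value pairs to cfg_from_list(). The Python 3 sketch below imitates, as we understand it, how cfg_from_list() applies such pairs. It works on a plain nested dict rather than the real cfg object. The function name apply_cfg_list and the example values are hypothetical.

```python
from ast import literal_eval


def apply_cfg_list(cfg, cfg_list):
    """Apply ['A.B', 'value', ...] overrides to a nested dict (sketch).

    Hypothetical stand-in for cfg_from_list(): keys are dotted paths,
    values are parsed with literal_eval and must keep the original type.
    """
    if len(cfg_list) % 2 != 0:
        raise ValueError('expected key/value pairs')
    for key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        *parents, leaf = key.split('.')
        d = cfg
        for sub in parents:
            d = d[sub]  # raises KeyError for an unknown section
        if leaf not in d:
            raise KeyError('unknown config key: {}'.format(key))
        try:
            value = literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # not a Python literal: keep it as a plain string
        if type(value) is not type(d[leaf]):
            raise TypeError('type {} does not match original type {}'.format(
                type(value), type(d[leaf])))
        d[leaf] = value
    return cfg


# Example values only, not necessarily Fast-RCNN's defaults.
cfg = {'RNG_SEED': 3, 'TRAIN': {'SNAPSHOT_ITERS': 10000}}
apply_cfg_list(cfg, ['TRAIN.SNAPSHOT_ITERS', '5000', 'RNG_SEED', '42'])
print(cfg)  # {'RNG_SEED': 42, 'TRAIN': {'SNAPSHOT_ITERS': 5000}}
```

Parsing with literal_eval turns '5000' into an int. The type check then rejects an override such as RNG_SEED abc, where the value would otherwise be silently stored as a string.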
The __main__ block above shows that train_net.py does three main things:
(1) First, it parses the command-line arguments with the following function, parse_args():
```python
def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=40000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args
```
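This function shows which optional arguments the training script accepts:
- --gpu (gpu_id, default 0): ID of the GPU device to use
- --solver (solver, default None): path to the solver prototxt
- --iters (max_iters, default 40000): number of training iterations
- --weights (pretrained_model, default None): pretrained model weights used for initialization
- --cfg (cfg_file, default None): optional config file
- --imdb (imdb_name, default voc_2007_trainval): dataset to train on
- --rand (randomize, flag): do not use a fixed random seed
- --set (set_cfgs, default None): config key/value pairs; it consumes all remaining arguments (argparse.REMAINDER)

If the script is run with no arguments at all, it prints the help message and exits.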
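The following Python 3 sketch shows how the launch command maps onto args. It rebuilds the same option table with the standard-library argparse module (help strings omitted). It then parses the command from the official documentation as an explicit token list, so sys.argv is never touched. build_parser is a hypothetical helper name, and the no-arguments help check is left out.

```python
import argparse


def build_parser():
    """Rebuild the option table of parse_args() (sketch, help text omitted)."""
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', default=0, type=int)
    parser.add_argument('--solver', dest='solver', default=None, type=str)
    parser.add_argument('--iters', dest='max_iters', default=40000, type=int)
    parser.add_argument('--weights', dest='pretrained_model', default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file', default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name', default='voc_2007_trainval', type=str)
    parser.add_argument('--rand', dest='randomize', action='store_true')
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER)
    return parser


# The documented launch command, split into tokens.
argv = ['--gpu', '0',
        '--solver', 'models/VGG16/solver.prototxt',
        '--weights', 'data/imagenet_models/VGG16.v2.caffemodel']
args = build_parser().parse_args(argv)
print(args.max_iters, args.imdb_name)  # 40000 voc_2007_trainval
```

Because --set uses nargs=argparse.REMAINDER, it swallows every token after it. It therefore has to come last on the command line.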
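(2) Next, it prepares the training samples from the dataset named by the --imdb argument. Two functions are involved: imdb = get_imdb(args.imdb_name) and roidb = get_training_roidb(imdb). We will spend a good deal of time on these two functions in the next part, so we skip them for now.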
(3) Finally, it starts training: train_net(args.solver, roidb, output_dir, pretrained_model=args.pretrained_model, max_iters=args.max_iters)
This train_net() function is imported from train.py in the fast-rcnn/lib/fast_rcnn/ folder, so let's look at train.py next.
The module consists mainly of one class, SolverWrapper, and two functions, get_training_roidb() and train_net().
First, let's look at train_net():
```python
def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print 'Solving...'
    sw.train_model(max_iters)
    print 'done solving'
```
This function delegates its main work to the SolverWrapper class. Let's step into the SolverWrapper constructor:
```python
def __init__(self, solver_prototxt, roidb, output_dir,
             pretrained_model=None):
    """Initialize the SolverWrapper."""
    self.output_dir = output_dir

    print 'Computing bounding-box regression targets...'
    self.bbox_means, self.bbox_stds = \
            rdl_roidb.add_bbox_regression_targets(roidb)
    print 'done'

    self.solver = caffe.SGDSolver(solver_prototxt)
    if pretrained_model is not None:
        print ('Loading pretrained model '
               'weights from {:s}').format(pretrained_model)
        self.solver.net.copy_from(pretrained_model)

    self.solver_param = caffe_pb2.SolverParameter()
    with open(solver_prototxt, 'rt') as f:
        pb2.text_format.Merge(f.read(), self.solver_param)

    self.solver.net.layers[0].set_roidb(roidb)
```
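The constructor first calls rdl_roidb.add_bbox_regression_targets(roidb), which returns bbox_means and bbox_stds. As we understand it, each proposal's regression target encodes its matched ground-truth box relative to the proposal. The encoding uses a center offset scaled by the proposal size plus log width and height ratios, which is the R-CNN box parametrization. The targets are then normalized by their mean and standard deviation. The Python 3 sketch below works under that assumption. It is simplified: it pools all targets instead of computing statistics per class, and the helper names encode_box and normalize_targets are hypothetical.

```python
import math


def encode_box(proposal, gt):
    """Regression target (dx, dy, dw, dh) of gt relative to proposal.

    Boxes are (x1, y1, x2, y2) in inclusive pixel coordinates, hence the +1.
    """
    ex_w = proposal[2] - proposal[0] + 1.0
    ex_h = proposal[3] - proposal[1] + 1.0
    ex_cx = proposal[0] + 0.5 * ex_w
    ex_cy = proposal[1] + 0.5 * ex_h
    gt_w = gt[2] - gt[0] + 1.0
    gt_h = gt[3] - gt[1] + 1.0
    gt_cx = gt[0] + 0.5 * gt_w
    gt_cy = gt[1] + 0.5 * gt_h
    return ((gt_cx - ex_cx) / ex_w,
            (gt_cy - ex_cy) / ex_h,
            math.log(gt_w / ex_w),
            math.log(gt_h / ex_h))


def normalize_targets(targets):
    """Return (normalized, means, stds), one statistic per coordinate."""
    n = len(targets)
    means = [sum(t[k] for t in targets) / n for k in range(4)]
    stds = []
    for k in range(4):
        var = sum((t[k] - means[k]) ** 2 for t in targets) / n
        stds.append(math.sqrt(var) or 1.0)  # avoid dividing by a zero std
    normalized = [tuple((t[k] - means[k]) / stds[k] for k in range(4))
                  for t in targets]
    return normalized, means, stds


# A 10x10 proposal and a 20x20 ground-truth box sharing the top-left corner.
print(encode_box((0, 0, 9, 9), (0, 0, 19, 19)))  # (0.5, 0.5, log 2, log 2)
```

Normalizing keeps the four target coordinates on a similar scale. The stored means and stds are presumably what allow predictions to be mapped back to pixel offsets later.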
Once initialization is done, train_model() is called to train the network. Here is its main part:
```python
def train_model(self, max_iters):
    """Network training loop."""
    last_snapshot_iter = -1
    timer = Timer()
    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()
        if self.solver.iter % (10 * self.solver_param.display) == 0:
            print 'speed: {:.3f}s / iter'.format(timer.average_time)

        if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            self.snapshot()

    if last_snapshot_iter != self.solver.iter:
        self.snapshot()
```
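The control flow of train_model() can be checked without Caffe. The Python 3 sketch below replaces self.solver with a hypothetical FakeSolver whose step(1) only advances an iteration counter. It records which iterations would trigger snapshot(): one every snapshot_iters iterations, plus a final one if training stops between two scheduled snapshots.

```python
class FakeSolver:
    """Hypothetical stand-in for caffe.SGDSolver: only counts iterations."""

    def __init__(self):
        self.iter = 0

    def step(self, n):
        self.iter += n  # a real solver would run n SGD updates here


def snapshot_iterations(max_iters, snapshot_iters):
    """Mirror the snapshot logic of train_model() and return when it fires."""
    solver = FakeSolver()
    snapshots = []
    last_snapshot_iter = -1
    while solver.iter < max_iters:
        solver.step(1)
        if solver.iter % snapshot_iters == 0:
            last_snapshot_iter = solver.iter
            snapshots.append(solver.iter)
    # Final snapshot so the last weights are saved even off the schedule.
    if last_snapshot_iter != solver.iter:
        snapshots.append(solver.iter)
    return snapshots


print(snapshot_iterations(25, 10))  # [10, 20, 25]
print(snapshot_iterations(20, 10))  # [10, 20]
```

The speed printout follows the same modulo pattern. It fires whenever the iteration count is a multiple of 10 * display, where display comes from the solver prototxt.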
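With this, the network can start training.
However, we have not yet reached the real centerpiece of Fast-RCNN: how the training data is prepared.
In the training flow described above, the relevant function is imdb = get_imdb(args.imdb_name).
This function is imported from factory.py in the lib/datasets/ folder. Here it is: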
```python
def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
```
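This function is simple. It looks up name in the __sets dictionary and calls the factory stored under that key to build the training dataset. An unknown name raises KeyError.
So how is this dictionary filled in? Here is the code that registers the INRIA dataset: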
```python
inria_devkit_path = '/home/jeremy/jWork/frcn/fast-rcnn/data/INRIA/'
for split in ['train', 'test']:
    name = '{}_{}'.format('inria', split)
    __sets[name] = (lambda split=split: datasets.inria(split, inria_devkit_path))
```
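__sets maps each dataset name to a zero-argument factory, so an imdb is only constructed when get_imdb() calls __sets[name](). The split=split default argument matters. A Python closure reads split when it is called, not when it is defined, so without the default every entry would build the last split. The Python 3 sketch below reproduces the pattern. It uses a hypothetical FakeImdb class in place of datasets.inria, a placeholder devkit path, and a plain dict named sets with an "in" check instead of the Python 2-only has_key.

```python
class FakeImdb:
    """Hypothetical stand-in for datasets.inria(split, devkit_path)."""

    def __init__(self, split, devkit_path):
        self.split = split
        self.devkit_path = devkit_path


devkit_path = '/path/to/INRIA/'  # placeholder path
sets = {}
for split in ['train', 'test']:
    name = '{}_{}'.format('inria', split)
    # The default argument captures the current value of split.
    sets[name] = (lambda split=split: FakeImdb(split, devkit_path))


def get_imdb(name):
    """Look up the factory by name and construct the imdb."""
    if name not in sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return sets[name]()


# Without the default argument, every closure sees the final value of split.
late = {}
for split in ['train', 'test']:
    late[split] = lambda: split

print(get_imdb('inria_train').split)  # train
print(late['train']())                # test
```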
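The datasets.inria class called here is defined in inria.py in the lib/datasets/ folder.
So we now need to go into inria.py. You have to write this file yourself, using pascal_voc.py as a reference.
First, let's look at the constructor of the inria class: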
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">def __init__(self, image_set, devkit_path):
    datasets.imdb.__init__(self, image_set)
    self._image_set = image_set
    self._devkit_path = devkit_path
    self._data_path = os.path.join(self._devkit_path, 'data')
    self._classes = ('__background__', # always index 0
                     '1001')
    self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
    self._image_ext = ['.jpg', '.png']
    self._image_index = self._load_image_set_index()
    # Default to roidb handler
    self._roidb_handler = self.selective_search_roidb

    # Specific config options
    self.config = {'cleanup'  : True,
                   'use_salt' : True,
                   'top_k'    : 2000}

    assert os.path.exists(self._devkit_path), \
            'Devkit path does not exist: {}'.format(self._devkit_path)
    assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)</code>
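The most important thing here is to update self._classes so that it matches the classes you are training on. In my case there are only two: '__background__' (always index 0) and '1001'.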
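To illustrate how self._classes drives the label lookup, the sketch below rebuilds the same dict(zip(...)) mapping in plain Python 3 (the original code base is Python 2 and uses xrange). The helper name build_class_to_ind is hypothetical. Because _load_inria_annotation looks each name up after .lower().strip(), the sketch also rejects class names that are not already lowercase and stripped, since those could never be matched.

```python
def build_class_to_ind(object_classes):
    """Hypothetical helper mirroring the label mapping in inria.__init__.

    '__background__' always gets index 0; the object classes follow in
    the order given.
    """
    classes = ('__background__',) + tuple(object_classes)
    for name in classes[1:]:
        # _load_inria_annotation applies .lower().strip() to the <name> tag
        # before the lookup, so entries that are not already lowercase and
        # stripped could never be matched.
        if name != name.lower().strip():
            raise ValueError('class name must be lowercase and stripped: %r' % name)
    class_to_ind = dict(zip(classes, range(len(classes))))
    return classes, class_to_ind


classes, class_to_ind = build_class_to_ind(['1001'])
print(classes)        # ('__background__', '1001')
print(class_to_ind)   # {'__background__': 0, '1001': 1}
print(len(classes))   # 2, the value num_classes would report for this tuple
```

Adding a class is then just a matter of appending its lowercase name; the background entry keeps index 0.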
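Once the inria object has been constructed, the training code calls roidb. roidb is inherited from the imdb class and hands the work off to _roidb_handler, which the constructor set to self.selective_search_roidb. Let's look at that function: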
<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">def selective_search_roidb(self):
    """
    Return the database of selective search regions of interest.
    Ground-truth ROIs are also included.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path,
                              self.name + '_selective_search_roidb.pkl')

    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fid:
            roidb = cPickle.load(fid)
        print '{} ss roidb loaded from {}'.format(self.name, cache_file)
        return roidb

    if self._image_set != 'test':
        gt_roidb = self.gt_roidb()
        ss_roidb = self._load_selective_search_roidb(gt_roidb)
        roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
    else:
        roidb = self._load_selective_search_roidb(None)
        print len(roidb)
    with open(cache_file, 'wb') as fid:
        cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
    print 'wrote ss roidb to {}'.format(cache_file)

    return roidb</code>
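To make the call chain concrete, here is a toy Python 3 re-creation of it: a lazily evaluated roidb property that delegates to _roidb_handler, which points at a selective_search_roidb-style method. That method first checks a pickle cache; otherwise it builds ground-truth and proposal entries, merges them per image and writes the cache. Everything here is an assumption for illustration, not the fast-rcnn code: ToyImdb and the fake box lists are hypothetical, the test-split branch is omitted, and merge_roidbs concatenates plain lists where the real datasets.imdb.merge_roidbs stacks NumPy arrays.

```python
import os
import pickle
import tempfile


def merge_roidbs(a, b):
    """Per-image concatenation: ground-truth entries first, then proposals."""
    assert len(a) == len(b)
    for entry_a, entry_b in zip(a, b):
        entry_a['boxes'] = entry_a['boxes'] + entry_b['boxes']
        entry_a['gt_classes'] = entry_a['gt_classes'] + entry_b['gt_classes']
    return a


class ToyImdb(object):
    """Hypothetical stand-in for datasets.imdb / inria, using fake box lists."""

    def __init__(self, name, cache_path, gt_boxes, ss_boxes):
        self.name = name
        self.cache_path = cache_path
        self._gt_boxes = gt_boxes    # per image: list of annotated boxes
        self._ss_boxes = ss_boxes    # per image: list of SS proposals
        self._roidb = None
        self.build_count = 0         # how many times the roidb was built from scratch
        # As in inria.__init__, choose the handler once.
        self._roidb_handler = self.selective_search_roidb

    @property
    def roidb(self):
        # Built on first access, then kept on the object.
        if self._roidb is None:
            self._roidb = self._roidb_handler()
        return self._roidb

    def selective_search_roidb(self):
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')
        if os.path.exists(cache_file):
            # The cache is keyed only by self.name, so an old file wins.
            with open(cache_file, 'rb') as fid:
                return pickle.load(fid)

        self.build_count += 1
        # Fake labels: 1 for every annotated box, 0 for every proposal.
        gt_roidb = [{'boxes': list(b), 'gt_classes': [1] * len(b)}
                    for b in self._gt_boxes]
        ss_roidb = [{'boxes': list(b), 'gt_classes': [0] * len(b)}
                    for b in self._ss_boxes]
        roidb = merge_roidbs(gt_roidb, ss_roidb)

        with open(cache_file, 'wb') as fid:
            pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
        return roidb


cache_dir = tempfile.mkdtemp()
gt = [[(10, 20, 50, 80)]]
ss = [[(12, 18, 48, 85), (0, 0, 30, 30)]]

first = ToyImdb('toy_train', cache_dir, gt, ss)
print(first.roidb[0]['boxes'])  # [(10, 20, 50, 80), (12, 18, 48, 85), (0, 0, 30, 30)]

second = ToyImdb('toy_train', cache_dir, gt, ss)
print(len(second.roidb), second.build_count)  # 1 0  (read back from the pickle)
```

The second object never rebuilds anything, which is exactly why, after changing the annotations or self._classes, the old .pkl files in the cache directory have to be deleted by hand.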
During training, selective_search_roidb() first calls the gt_roidb() function:
<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fid:
            roidb = cPickle.load(fid)
        print '{} gt roidb loaded from {}'.format(self.name, cache_file)
        return roidb

    gt_roidb = [self._load_inria_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
        cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
    print 'wrote gt roidb to {}'.format(cache_file)

    return gt_roidb</code>
If cache_file exists, gt_roidb() reads the annotations straight from it; otherwise it calls _load_inria_annotation() on every image to obtain the annotations. _load_inria_annotation is shown below:
<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">def _load_inria_annotation(self, index):
    """
    Load image and bounding boxes info from XML files of INRIA Person.
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    print 'Loading: {}'.format(filename)
    def get_data_from_tag(node, tag):
        return node.getElementsByTagName(tag)[0].childNodes[0].data

    with open(filename) as f:
        data = minidom.parseString(f.read())

    objs = data.getElementsByTagName('object')
    num_objs = len(objs)

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):
        # Make pixel indexes 0-based
        x1 = float(get_data_from_tag(obj, 'xmin')) - 1
        y1 = float(get_data_from_tag(obj, 'ymin')) - 1
        x2 = float(get_data_from_tag(obj, 'xmax')) - 1
        y2 = float(get_data_from_tag(obj, 'ymax')) - 1
        # ---------------------------------------------
        # add these lines to avoid the assertion error
        if x1 < 0:
            x1 = 0
        if y1 < 0:
            y1 = 0
        # ----------------------------------------------
        cls = self._class_to_ind[
                str(get_data_from_tag(obj, "name")).lower().strip()]
        boxes[ix, :] = [x1, y1, x2, y2]
        gt_classes[ix] = cls
        overlaps[ix, cls] = 1.0

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes' : boxes,
            'gt_classes': gt_classes,
            'gt_overlaps' : overlaps,
            'flipped' : False}</code>
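To see what the loader extracts from a single annotation file, the sketch below runs a made-up XML string with the same object/name/xmin/ymin/xmax/ymax tags through the same xml.dom.minidom calls, in Python 3 syntax. The function name parse_annotation and the sample data are hypothetical, and it returns plain lists instead of the np.uint16 box array and scipy.sparse overlap matrix built above. The clamp matters because an xmin or ymin of 0 becomes -1 after the 0-based shift; stored in an unsigned np.uint16 array, such a value would presumably wrap to a very large number and break later sanity checks on the box coordinates, which is likely the assertion error the comment refers to.

```python
from xml.dom import minidom

# Hypothetical annotation with two objects; the second has xmin = 0.
SAMPLE_XML = """<annotation>
  <object>
    <name>1001</name>
    <bndbox><xmin>1</xmin><ymin>5</ymin><xmax>40</xmax><ymax>90</ymax></bndbox>
  </object>
  <object>
    <name> 1001 </name>
    <bndbox><xmin>0</xmin><ymin>12</ymin><xmax>25</xmax><ymax>60</ymax></bndbox>
  </object>
</annotation>"""


def get_data_from_tag(node, tag):
    # Text of the first matching descendant tag.
    return node.getElementsByTagName(tag)[0].childNodes[0].data


def parse_annotation(xml_text, class_to_ind):
    """Hypothetical helper mirroring _load_inria_annotation (plain lists)."""
    data = minidom.parseString(xml_text)
    boxes, gt_classes = [], []
    for obj in data.getElementsByTagName('object'):
        # Make pixel indexes 0-based
        x1 = float(get_data_from_tag(obj, 'xmin')) - 1
        y1 = float(get_data_from_tag(obj, 'ymin')) - 1
        x2 = float(get_data_from_tag(obj, 'xmax')) - 1
        y2 = float(get_data_from_tag(obj, 'ymax')) - 1
        # Same clamp as above: an xmin/ymin of 0 would otherwise become -1
        x1 = max(x1, 0.0)
        y1 = max(y1, 0.0)
        # Same normalization as the loader before the class lookup
        name = str(get_data_from_tag(obj, 'name')).lower().strip()
        boxes.append([x1, y1, x2, y2])
        gt_classes.append(class_to_ind[name])
    return {'boxes': boxes, 'gt_classes': gt_classes, 'flipped': False}


class_to_ind = {'__background__': 0, '1001': 1}
print(parse_annotation(SAMPLE_XML, class_to_ind))
```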
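Once the annotations have been processed, the next step is to load the boxes produced by the selective search (SS) stage, which is done by the following function: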
<code class="hljs python has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">def _load_selective_search_roidb(self, gt_roidb):
    filename = os.path.abspath(os.path.join(self._devkit_path,
                                            self.name + '.mat'))
    assert os.path.exists(filename), \
           'Selective search data not found at: {}'.format(filename)
    raw_data = sio.loadmat(filename)['boxes'].ravel()

    box_list = []
    for i in xrange(raw_data.shape[0]):
        # Note: if you already converted the box values when running SS,
        # there is no need to reorder the coordinates again here
        #box_list.append(raw_data[i][:, (1, 0, 3, 2)] - 1)
        box_list.append(raw_data[i][:, (1, 0, 3, 2)])

    return self.create_roidb_from_box_list(box_list, gt_roidb)</code>
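One thing to keep in mind is that the box values produced by SS differ slightly from what Fast R-CNN expects: the x and y coordinates of each box have to be swapped. The index tuple (1, 0, 3, 2) in the code above does this, turning each row from (y1, x1, y2, x2) into (x1, y1, x2, y2).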
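The swap itself is only a column permutation. The sketch below assumes, as the (1, 0, 3, 2) indexing implies, that each SS row is stored as [y1, x1, y2, x2]; the optional subtraction of 1 corresponds to the commented-out line, which presumably converts MATLAB's 1-based pixel indices to 0-based ones. Both helper names and the bounds check are hypothetical additions: on non-square images, checking boxes against the image size is one cheap way to notice boxes that were swapped twice, or not at all.

```python
def ss_rows_to_xyxy(rows, one_based=True):
    """Reorder SS rows from (y1, x1, y2, x2) to (x1, y1, x2, y2).

    Hypothetical helper; one_based=True also subtracts 1, like the
    commented-out line in _load_selective_search_roidb.
    """
    offset = 1 if one_based else 0
    return [[row[1] - offset, row[0] - offset, row[3] - offset, row[2] - offset]
            for row in rows]


def out_of_bounds(boxes, width, height):
    """Return the boxes that do not fit inside a width x height image (0-based)."""
    return [b for b in boxes
            if b[0] < 0 or b[1] < 0 or b[2] > width - 1 or b[3] > height - 1]


# A wide 500x200 image; the proposal covers x in [300, 450], y in [10, 150].
ss_rows = [[11, 301, 151, 451]]            # 1-based (y1, x1, y2, x2)
boxes = ss_rows_to_xyxy(ss_rows)
print(boxes)                               # [[300, 10, 450, 150]]
print(out_of_bounds(boxes, 500, 200))      # []
# Without the swap, y1/y2 would be read as 300/450, beyond the 200-pixel height:
print(out_of_bounds([[10, 300, 150, 450]], 500, 200))  # [[10, 300, 150, 450]]
```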
To be continued...