Faster-rcnn源码解析2

经过了数据准备阶段,得到roidb和 imdb,下面利用得到的数据roidb进入网络的训练阶段:

model_paths = train_net(solver, roidb, output_dir,pretrained_model=init_model,max_iters=max_iters)

# 假设:net_name=‘ZF’,那么solver:ZF_faster_rcnn_alt_opt_stage1_rpn_solver60k80k.pt,

#  max_iters:80000

进入train_net函数:

首先,使用filter_roidb函数对roidb进行过滤,过滤掉无效的图片数据。

然后,创建 SolverWrapper类:

sw = SolverWrapper(solver_prototxt, roidb, output_dir,pretrained_model=pretrained_model)

进入 SolverWrapper类的初始化函数:


Faster-rcnn源码解析2_第1张图片

注意,cfg.TRAIN.BBOX_REG =False。所以,这两个if语句都不执行。


Faster-rcnn源码解析2_第2张图片

self.solver = caffe.SGDSolver(solver_prototxt):利用stage1_rpn_solver60k80k.pt文件初始化网络的优化器

self.solver.net.layers[0].set_roidb(roidb):根据roi_data_layer.layer文件中的set_roidb函数来对roidb进行随机打乱

回到train_net函数,执行:model_paths = sw.train_model(max_iters),返回值是一个列表。

进入 sw.train_model函数:


每200次打印一次结果。


每10000次保存一次网络,注意,在stage1_rpn_solver60k80k.pt文件中,snapshot:0,因此,这里没用使用caffe自带的snapshot来保存网络结果,而是用的自己定义的snapshot。

进入snapshot,发现这个函数返回值是一个:filename文件(保存的网络的绝对路径),因此model_paths 返回的结果是一个filename文件的列表。这也是train_net函数的返回结果。

回到train_rpn函数中:


Faster-rcnn源码解析2_第3张图片

将得到的model_paths 列表中的元组只保留最后一个,其余的全部移除,也就是只保留最新的那个网络结果,然后把这个结果以字典的形式推入进程队列中。

最后,回到:p = mp.Process(target=train_rpn,kwargs=mp_kwargs),注意,这里的p只是创建进程,接下来,我们启动进程:p.start(),从进程队列中取出刚才的字典的value:rpn_stage1_out = mp_queue.get(),然后等待进程结束:p.join()

这样,我们就得到了训练的rnp网络:rpn_stage1_out,然后把它用在下一步中。

你可能感兴趣的:(Faster-rcnn源码解析2)