faster rcnn中train.py

SolverWrapper是对Caffe solver的一个简单包装类,主要目的是自行控制snapshot(快照)的保存过程。值得细讲的地方不多,但为了让读者对整个训练流程的理解更加连贯,本文仍单独用一节来分析这部分源码。

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: path to the solver definition; it is handed to
                caffe.SGDSolver and also parsed into self.solver_param.
            roidb: region-of-interest database; used to compute bbox
                regression targets (when enabled) and passed to the
                network's first layer.
            output_dir: stored on the instance as self.output_dir.
            pretrained_model: optional weights file copied into the net
                before training; skipped when None.
        """
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            print('Computing bounding-box regression targets...')
            # bbox_means / bbox_stds only exist on the instance when
            # BBOX_REG is enabled.
            self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
            print('done')

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            # Format first, then print: the same output under the Python 2
            # print statement, without relying on its precedence rules.
            print('Loading pretrained model '
                  'weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the solver prototxt a second time into a protobuf message so
        # its fields can be read from self.solver_param.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # All of the data preparation so far leads up to this call: hand the
        # roidb to the network's first layer.
        self.solver.net.layers[0].set_roidb(roidb)
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

snapshot

自主实现了snapshot,精读的意义不大。

def snapshot(self):
    """Take a snapshot of the network after unnormalizing the learned
    bounding-box regression weights. This enables easy use at test-time.

    The in-memory 'bbox_pred' parameters are temporarily rescaled with
    self.bbox_stds / self.bbox_means for the save and are put back to their
    original values afterwards, even if saving fails.

    Returns:
        The path of the written .caffemodel file.
    """
    net = self.solver.net

    # `in` replaces dict.has_key(), which only exists on Python 2 dicts.
    scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                         cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                         'bbox_pred' in net.params)

    if scale_bbox_params:
        # save original values
        orig_0 = net.params['bbox_pred'][0].data.copy()
        orig_1 = net.params['bbox_pred'][1].data.copy()

        # scale and shift with bbox reg unnormalization; then save snapshot
        net.params['bbox_pred'][0].data[...] = \
            (net.params['bbox_pred'][0].data *
             self.bbox_stds[:, np.newaxis])
        net.params['bbox_pred'][1].data[...] = \
            (net.params['bbox_pred'][1].data *
             self.bbox_stds + self.bbox_means)

    try:
        # An empty SNAPSHOT_INFIX contributes nothing to the file name.
        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        filename = (self.solver_param.snapshot_prefix + infix +
                    '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
        filename = os.path.join(self.output_dir, filename)

        net.save(str(filename))
        print('Wrote snapshot to: {:s}'.format(filename))
    finally:
        if scale_bbox_params:
            # restore net to original state so training continues with the
            # normalized weights, whether or not the save succeeded
            net.params['bbox_pred'][0].data[...] = orig_0
            net.params['bbox_pred'][1].data[...] = orig_1
    return filename
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

train_model

训练主流程,打印了一些时间等信息,并控制了snapshot的过程。

def train_model(self, max_iters):
    """Network training loop.

    Steps the solver one iteration at a time until self.solver.iter reaches
    max_iters, snapshotting every cfg.TRAIN.SNAPSHOT_ITERS iterations and
    once more at the end if the final iteration was not already saved.

    Args:
        max_iters: iteration count at which training stops.

    Returns:
        List of the paths returned by every self.snapshot() call, in order.
    """
    last_snapshot_iter = -1
    timer = Timer()
    model_paths = []

    # Timing is reported every 10 display intervals. A display value of 0
    # would make the modulo below raise ZeroDivisionError, so reporting is
    # simply skipped in that case.
    report_every = 10 * self.solver_param.display

    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()
        if report_every > 0 and self.solver.iter % report_every == 0:
            print('speed: {:.3f}s / iter'.format(timer.average_time))

        if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            model_paths.append(self.snapshot())

    # Make sure the final state is saved when max_iters does not fall on a
    # snapshot boundary.
    if last_snapshot_iter != self.solver.iter:
        model_paths.append(self.snapshot())
    return model_paths
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

get_training_roidb

如果设置了cfg.TRAIN.USE_FLIPPED,该函数会先将imdb中的每张图片水平翻转并追加到数据集中(一种数据增强手段,可以降低过拟合的可能性);随后调用prepare_roidb做一些准备性的工作,最后返回imdb.roidb。

def get_training_roidb(imdb):
    """Return the roidb (Region of Interest database) used for training.

    When cfg.TRAIN.USE_FLIPPED is set, horizontally-flipped images are
    appended to imdb first. The imdb is then run through
    rdl_roidb.prepare_roidb and its roidb attribute is returned.
    """
    augment_with_flips = cfg.TRAIN.USE_FLIPPED
    if augment_with_flips:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    training_roidb = imdb.roidb
    return training_roidb
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

filter_roidb

该函数内部定义了一个is_valid函数,用于判断roidb中的每个entry是否合理;“合理”定义为至少有一个前景box或背景box。 
当roidb全是ground truth时,box与对应类别的重合度(overlaps)显然为1,也就是说每个entry起码要有一个标注的类别。 
如果roidb包含了一些proposal,overlaps在[BG_THRESH_LO, BG_THRESH_HI)之间的都将被认为是背景,大于等于FG_THRESH的才被认为是前景;每个entry至少要有一个前景或背景,否则将被过滤掉。 
将没用的entry过滤掉以后,返回的就是filtered_roidb。

def filter_roidb(roidb):
    """Remove roidb entries that have no usable RoIs.

    Args:
        roidb: sequence of entries; each entry's 'max_overlaps' value is
            compared against the cfg.TRAIN thresholds.

    Returns:
        A new list holding only the entries with at least one foreground
        or background RoI. The input sequence is not modified.
    """

    def is_valid(entry):
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        # This return belongs to is_valid; one level further out it would
        # end filter_roidb before any filtering happened.
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print('Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after))
    return filtered_roidb
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

train_net

该函数接收solver配置文件、roidb等参数:先用filter_roidb过滤数据,再构造SolverWrapper并调用train_model完成网络训练,最后返回训练过程中保存的各个snapshot的路径。

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network.

    The roidb is passed through filter_roidb, wrapped together with the
    solver definition in a SolverWrapper, and trained via its train_model
    method for max_iters iterations.

    Returns:
        Whatever SolverWrapper.train_model returns (the snapshot paths).
    """
    usable_roidb = filter_roidb(roidb)
    wrapper = SolverWrapper(solver_prototxt, usable_roidb, output_dir,
                            pretrained_model=pretrained_model)

    print('Solving...')
    snapshot_paths = wrapper.train_model(max_iters)
    print('done solving')
    return snapshot_paths

你可能感兴趣的:(faster,rcnn)