Fast Rcnn 之数据准备阶段 code 分享

Fast Rcnn 之数据准备阶段 code 分享


1.Rcnn系列简介

rcnn系列是object detection领域经典算法,从rcnn到fast-rcnn再到faster-rcnn,三篇工作都有Ross Girshick大神的重要贡献。关于object detection系列的算法思想介绍,有很多博客介绍的很清晰,推荐 cs231n学习笔记-CNN-目标检测、定位、分割,但是关于fast rcnn或者是faster rcnn工程中的code介绍却不多见。

step1 数据准备阶段(roidb)

整体流程介绍

先从整体出发介绍数据准备阶段的流程框架,在掌握大体框架之后,再去看具体代码功能实现,这样有助于理清头绪和快速掌握,不至于陷入一些恼人的代码细节之中。

初始数据

初始数据包括 image_index, groundtruth_annotation, selective_search_box
1. image_index即是 image names
2. gt_annotation为人工标注的box位置,每一个box为四元组
3. selective_search_box为 offline 计算好的 proposal box

大体流程
  1. 读每个 image 对应的 gt_annotation,将 [box_location, gt_class, gt_overlap …] 等重要信息存入 roidb。
  2. 读每个 image 对应的 selective_box,将 [box_location, gt_class, gt_overlap …] 等重要信息存入 roidb。
  3. 水平翻转图片做 data augmentation;存入图片的路径,为训练时读图做准备。
  4. 计算每个box的回归目标;以类别为粒度,对 box 信息进行归一化。
  5. 初始化网络,将 roidb 送入第一层。

具体详解

启动训练的脚本是 ./tools/train_net.py,按照代码顺序,解释其数据准备的几个关键流程。

### ./tools/train_net.py
if __name__ == '__main__':
    args = parse_args() # 1.参数解析

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
        caffe.set_random_seed(cfg.RNG_SEED)

    # set up caffe
    caffe.set_mode_gpu()
    if args.gpu_id is not None:
        caffe.set_device(args.gpu_id)

    imdb = get_imdb(args.imdb_name) # 2.产生roidb数据集
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    roidb = get_training_roidb(imdb) # 3.为roidb准备训练时所需信息

    output_dir = get_output_dir(imdb, None)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    train_net(args.solver, roidb, output_dir, # 4.设定参数并训练。
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)

1.参数解析

调用python的参数解析模块argparse,解析的信息包括 [gpu id, solver text, weight(pre-trained model, ….)]。具体实现在 ./tools/train_net.pydef parse_args():中。

2.产生roidb数据集

imdb = get_imdb(args.imdb_name) # 2.产生roidb数据集,其中get_imdb()是从./lib/datasets/factory.py中导入。具体代码为

### ./lib/datasets/factory.py
# Set up voc__ using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year:
                datasets.pascal_voc(split, year))
... #省略部分代码
...
def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()

由此可看出,get_imdb()返回的是datasets.pascal_voc(split, year)的一个实例对象,此类从./lib/datasets/pascal_voc.py中导入。具体代码,

### ./lib/datasets/pascal_voc.py
class pascal_voc(datasets.imdb):
    def __init__(self, image_set, year, devkit_path=None):
        datasets.imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year) # 初始数据的根路径
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) # 没个类别对应一个整数
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index() # 图片的index(即names)
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb # gt_box 和 selective_box 共同构造出 roidb

        # PASCAL specific config options
        self.config = {'cleanup'  : True,
                       'use_salt' : True,
                       'top_k'    : 2000}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)

重点介绍其中self._roidb_handler = self.selective_search_roidb # gt_box 和 selective_box 共同构造出 roidb,此过程实际分三步,首先根据 gt_annotation 构造 gt_roidb,然后依据 selective_box 构造ss_roidb,最后将而这合并为一个 roidb
具体代码如下,

### ./lib/datasets/pascal_voc.py
    def selective_search_roidb(self):
    ... #省略部分代码
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
    ... #省略部分代码

gt_roidb, ss_roidb都是以 image 为粒度进行构造( 即 gt_roidb[img_index] 为具体内容),具体过程都是读每张图片对应的标注数据,按照键值对形式进行存储,具体格式如下:

{'boxes' : boxes, # [ [x1, y1, x2, y2] [x1`, y1`, x2`, y2`] ... ] 每一个标注的box,对应一个四元组
'gt_classes': gt_classes, # [ label_1 label_2 ... ] 每一个box对应一个 label(21种类别之一)
'gt_overlaps' : overlaps, # [ [lap_0 lap_1 ... lap_20] [lap_0` lap_1` ... lap_20`] ... ] 每一个box对应一个21维元组,表示和每一种类别的重合度(gt_annotation时,则于某一类别重合度最高,为1),
'flipped' : False} # 是否使用水平翻转之后图片,用做 data augmentation

合并过程,同样以 image 为粒度,将 gt_roidb[img_index], ss_roidb[img_index] 连接合并。至此,roidb已基本准备完毕。

3.为roidb准备训练时所需信息

roidb = get_training_roidb(imdb) # 3.为roidb准备训练时所需信息,此函数由 ./lib/fast_rcnn/train.py中导入。具体如下:

### ./lib/fast_rcnn/train.py
def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images() # 水平翻转图片,data augmentation
        print 'done'

    print 'Preparing training data...'
    rdl_roidb.prepare_roidb(imdb) # 准备训练时用到的信息
    print 'done'

    return imdb.roidb

rdl_roidb.prepare_roidb(imdb) # 准备训练时用到的信息 中重要的一步是添加每张imageimage_path。好吧,到这里,估计你也注意到了,roi_db准备了一大堆,并没有读图,只是有图片路径和标注数据。

4.设定参数并训练。

train_net(args.solver, roidb, output_dir, # 4.设定参数并训练。....,函数在./lib/fast_rcnn/train.py中,

### ./lib/fast_rcnn/train.py
def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model) #调用pycaffe接口,完成网络配置

    print 'Solving...'
    sw.train_model(max_iters) #迭代训练
    print 'done solving'

网络配置初始化阶段如下,

### ./lib/fast_rcnn/train.py
class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.
    This wrapper gives us control over he snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """
    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper."""
        self.output_dir = output_dir

        print 'Computing bounding-box regression targets...'
        self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb) #为每个box(gt, ss都有)计算其应该的回归目标
        print 'done'

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print ('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model)
            self.solver.net.copy_from(pretrained_model)

        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        self.solver.net.layers[0].set_roidb(roidb) #将roidb输入第一层

rdl_roidb.add_bbox_regression_targets(roidb) #为每个box(gt, ss都有)计算其应该的回归目标中,简单来说,groudtruth box的回归目标就是自身; 当某些selective boxgroudtruth box重合度达到一定threshold时,则该selective box的回归目标就是此 groudtruth box(多个时,取重合度max的); 其余selective box的回归目标就是0了,因为其标签为 __back-ground__。然后以类别为粒度,将 box 信息进行归一化。
self.solver.net.layers[0].set_roidb(roidb) #将roidb输入第一层则是将准备好的roidb送入网络的第一层。

layer {
  name: 'data'
  type: 'Python'
  top: 'data'
  top: 'rois'
  top: 'labels'
  top: 'bbox_targets'
  top: 'bbox_loss_weights'
  python_param {
    module: 'roi_data_layer.layer'
    layer: 'RoIDataLayer'
    param_str: "'num_classes': 21"
  }
}

中国科学技术大学多媒体计算与通信教育部-微软重点实验室

MultiMedia Computing Group

我们的主页

你可能感兴趣的:(Deep,Learning,coding,算法)