Faster-RCNN_TF Source Code Walkthrough: Data Preprocessing

A few days ago I read through the Faster-RCNN_TF source code carefully and learned a lot, so I am writing a few posts to record it.

Part 1: Data Preprocessing
Using the VOC2007 dataset as an example, this post traces how the data makes its way, step by step, from the files on disk into the depths of the neural network.

1、./experiments/scripts/faster_rcnn_end2end.sh
The README.md tells us that training this network starts by running faster_rcnn_end2end.sh, as follows:

cd $FRCN_ROOT
./experiments/scripts/faster_rcnn_end2end.sh $DEVICE $DEVICE_ID VGG16 pascal_voc

faster_rcnn_end2end.sh takes the arguments DEVICE, DEVICE_ID, VGG16, and pascal_voc (here VGG16 is only used to name the log file).
Based on pascal_voc it sets
TRAIN_IMDB="voc_2007_trainval"
ITERS=70000

The remaining options keep their default values, and everything is passed on to ./tools/train_net.py:
device gpu
device_id 0
weights data/pretrain_model/VGG_imagenet.npy
imdb "voc_2007_trainval"
iters 70000
cfg experiments/cfgs/faster_rcnn_end2end.yml
network VGGnet_train

2、./tools/train_net.py
(1) parse_args() parses the 7 arguments passed in above.
(2) cfg_from_file(args.cfg_file) then loads the YAML config file and merges it into the config dictionary cfg, i.e. it mainly loads a set of configuration options.
(3) imdb = get_imdb(args.imdb_name) then reads the basic information of the image database.
…… (for now we only look at how the data is loaded)
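
For intuition, cfg_from_file essentially loads the YAML file and recursively overwrites the matching keys in the default config. A minimal sketch of that merge step (not the repo's exact _merge_a_into_b, just the idea; load_yaml_into is a made-up name):

import yaml

def merge_config(src, dst):
    """Recursively overwrite keys in dst with values from src."""
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            merge_config(value, dst[key])
        else:
            dst[key] = value

def load_yaml_into(cfg, cfg_file):
    """Roughly what cfg_from_file does with args.cfg_file."""
    with open(cfg_file) as f:
        merge_config(yaml.safe_load(f), cfg)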

3、./lib/datasets/factory.py
The get_imdb() function lives in ./lib/datasets/factory.py. Datasets are registered by year and split into the module-level dictionary __sets = {}, which gets populated inside a few for loops. Since the name passed in is args.imdb_name = "voc_2007_trainval", we only need to look at one of those loops:

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        print name
        __sets[name] = (lambda split=split, year=year:
                datasets.pascal_voc(split, year))
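
With __sets filled in, get_imdb itself is tiny: it looks the name up and calls the stored lambda. Roughly (paraphrased, not copied verbatim):

def get_imdb(name):
    """Return an imdb (image database) by name."""
    if name not in __sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    # Calling the stored lambda finally constructs datasets.pascal_voc(split, year).
    return __sets[name]()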

4、./lib/datasets/pascal_voc.py

class pascal_voc(imdb):
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        #self._roidb_handler = self.selective_search_roidb
        self._roidb_handler = self.gt_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)

pascal_voc is a class that inherits from imdb. With the arguments above it is initialized with "trainval" and "2007". Inside __init__,

        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path

determines the path of the VOC2007 devkit, and the following lines spell out the classes and assign each one an index:

        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))

self._image_index holds the names (without extension) of every image listed in ./VOCdevkit2007/VOC2007/ImageSets/Main/trainval.txt:

self._image_index = self._load_image_set_index()
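
For reference, _load_image_set_index does little more than read that text file line by line. A sketch of the idea (the real method also asserts that the file exists):

def _load_image_set_index(self):
    # e.g. VOCdevkit2007/VOC2007/ImageSets/Main/trainval.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    with open(image_set_file) as f:
        return [line.strip() for line in f]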

Next, the ground-truth box information of every image is read:

self._roidb_handler = self.gt_roidb
 def gt_roidb(self):
     """
     Return the database of ground-truth regions of interest.

     This function loads/saves from/to a cache file to speed up future calls.
     """
     cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
     if os.path.exists(cache_file):
         with open(cache_file, 'rb') as fid:
             roidb = cPickle.load(fid)
         print '{} gt roidb loaded from {}'.format(self.name, cache_file)
         return roidb

     gt_roidb = [self._load_pascal_annotation(index)
                 for index in self.image_index]
     with open(cache_file, 'wb') as fid:
         cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
     print 'wrote gt roidb to {}'.format(cache_file)

     return gt_roidb

There is quite a bit of code above, but the only line that really matters is the call to self._load_pascal_annotation(index): given an image name, it opens the corresponding XML annotation file and reads the labels. The author parses the XML with xml.etree.ElementTree (I may cover its usage in a separate post). In the end, a dictionary is returned for each image:

# boxes = np.zeros((num_objs, 4), dtype=np.uint16)
# gt_classes = np.zeros((num_objs), dtype=np.int32)
# overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
# seg_areas = np.zeros((num_objs), dtype=np.float32)
# Note that overlaps is stored as a sparse matrix; the dense array is converted first:
overlaps = scipy.sparse.csr_matrix(overlaps)

return {'boxes' : boxes,
        'gt_classes': gt_classes,
        'gt_overlaps' : overlaps,
        'flipped' : False,
        'seg_areas' : seg_areas}

The full code:

    def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
                'gt_classes': gt_classes,
                'gt_overlaps' : overlaps,
                'flipped' : False,
                'seg_areas' : seg_areas}

5、Back to ./tools/train_net.py
The call imdb = get_imdb(args.imdb_name) only gives us the basic information of the image database, which is not yet enough to train a Faster R-CNN network, so roidb = get_training_roidb(imdb) is called next to enrich the data further.

6、./lib/fast_rcnn/train.py
roidb = get_training_roidb(imdb) takes the imdb from the previous step and returns the data actually used for training. In the author's own words:

"""Returns a roidb (Region of Interest database) for use in training."""

The function is shown below. Roughly, it first horizontally flips the images to double the amount of training data (a simple form of augmentation) and then prepares the data with prepare_roidb(imdb):

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
            gdl_roidb.prepare_roidb(imdb)
        else:
            rdl_roidb.prepare_roidb(imdb)
    else:
        rdl_roidb.prepare_roidb(imdb)
    print 'done'

    return imdb.roidb

Inside imdb.append_flipped_images(), only the box coordinates in the previously built imdb.roidb are mirrored, which doubles the amount of data:

    entry = {'boxes' : boxes,
             'gt_overlaps' : self.roidb[i]['gt_overlaps'],
             'gt_classes' : self.roidb[i]['gt_classes'],
             'flipped' : True}
    self.roidb.append(entry)
    self._image_index = self._image_index * 2
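
Concretely, a box [x1, y1, x2, y2] is mirrored around the vertical centre of the image: only the x extents change. A minimal standalone sketch of that transformation (assuming 0-based pixel coordinates, as in the repo):

def flip_box_horizontally(box, width):
    """Mirror an [x1, y1, x2, y2] box horizontally within an image of the given width."""
    x1, y1, x2, y2 = box
    return [width - x2 - 1, y1, width - x1 - 1, y2]

print(flip_box_horizontally([10, 20, 50, 80], width=500))  # -> [449, 20, 489, 80]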

rdl_roidb.prepare_roidb(imdb) then enriches imdb.roidb further: it first opens every image to learn its size, and then adds the following fields for each entry:

        roidb[i]['image'] = imdb.image_path_at(i)    # path of the image file
        roidb[i]['width'] = sizes[i][0]
        roidb[i]['height'] = sizes[i][1]
        # need gt_overlaps as a dense array for argmax
        gt_overlaps = roidb[i]['gt_overlaps'].toarray()    # convert the sparse matrix back to a dense array
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        roidb[i]['max_classes'] = max_classes   # for each object, the class index with the largest overlap
        roidb[i]['max_overlaps'] = max_overlaps # the corresponding overlap value (always 1 here, since the boxes are ground truth)
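
The image sizes are obtained simply by opening every file once, along the lines of the following sketch (the repo uses PIL for this as well; treat the helper as illustrative):

from PIL import Image

def image_sizes(imdb):
    """Collect (width, height) for every image in the database."""
    return [Image.open(imdb.image_path_at(i)).size
            for i in range(imdb.num_images)]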

At this point, the basic preprocessing of the data is done.

7、Back to ./tools/train_net.py again
With the data ready, train_net() in ./lib/fast_rcnn/train.py is called next.
The roidb is first filtered to drop entries with no usable RoIs, although I suspect this step does nothing here: all the max_overlaps values are 1, so none of them falls below the threshold (corrections welcome). See the sketch after the code below.

roidb = filter_roidb(roidb)
sw = SolverWrapper(sess, saver, network, imdb, roidb, output_dir, pretrained_model=pretrained_model)
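
For context, filter_roidb keeps an entry only if it has at least one foreground RoI (max overlap at or above FG_THRESH) or at least one usable background RoI (max overlap inside the background range). A rough sketch of that check; the real thresholds come from cfg.TRAIN.*, the numbers below are only illustrative:

import numpy as np

def has_usable_rois(entry, fg_thresh=0.5, bg_thresh_lo=0.0, bg_thresh_hi=0.5):
    """An entry survives the filter if it offers at least one fg or one bg RoI."""
    overlaps = entry['max_overlaps']
    fg = np.where(overlaps >= fg_thresh)[0]
    bg = np.where((overlaps >= bg_thresh_lo) & (overlaps < bg_thresh_hi))[0]
    return len(fg) > 0 or len(bg) > 0

# Every gt box has max_overlaps == 1, so fg is never empty here,
# which is why the filter removes nothing in this setup.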

8、./lib/fast_rcnn/train.py
A SolverWrapper object is constructed, which computes the bounding-box regression targets:

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.
    This wrapper gives us control over he snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, sess, saver, network, imdb, roidb, output_dir, pretrained_model=None):
        """Initialize the SolverWrapper."""
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.output_dir = output_dir
        self.pretrained_model = pretrained_model

        print 'Computing bounding-box regression targets...'
        if cfg.TRAIN.BBOX_REG:
            self.bbox_means, self.bbox_stds = rdl_roidb.add_bbox_regression_targets(roidb)
        print 'done'

        # For checkpoint
        self.saver = saver

9、./lib/roi_data_layer/roidb.py
Since everything passed in here is ground truth, roidb[im_i]['bbox_targets'] simply holds the [dx, dy, dw, dh] regression deltas computed from the gt boxes (a sketch of that parameterization follows the function below).
I have not looked closely at how the returned means and stds are used; judging from the SolverWrapper docstring above, they are later needed to unnormalize the learned bounding-box regression weights when snapshotting.

def add_bbox_regression_targets(roidb):
    """Add information needed to train bounding-box regressors."""
    assert len(roidb) > 0
    assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'

    num_images = len(roidb)
    # Infer number of classes from the number of columns in gt_overlaps
    num_classes = roidb[0]['gt_overlaps'].shape[1]
    for im_i in xrange(num_images):
        rois = roidb[im_i]['boxes']
        max_overlaps = roidb[im_i]['max_overlaps']
        max_classes = roidb[im_i]['max_classes']
        roidb[im_i]['bbox_targets'] = \
                _compute_targets(rois, max_overlaps, max_classes)

    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
        # Use fixed / precomputed "means" and "stds" instead of empirical values
        means = np.tile(
                np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS), (num_classes, 1))
        stds = np.tile(
                np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS), (num_classes, 1))
    else:
        # Compute values needed for means and stds
        # var(x) = E(x^2) - E(x)^2
        class_counts = np.zeros((num_classes, 1)) + cfg.EPS
        sums = np.zeros((num_classes, 4))
        squared_sums = np.zeros((num_classes, 4))
        for im_i in xrange(num_images):
            targets = roidb[im_i]['bbox_targets']
            for cls in xrange(1, num_classes):
                cls_inds = np.where(targets[:, 0] == cls)[0]
                if cls_inds.size > 0:
                    class_counts[cls] += cls_inds.size
                    sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)
                    squared_sums[cls, :] += \
                            (targets[cls_inds, 1:] ** 2).sum(axis=0)

        means = sums / class_counts
        stds = np.sqrt(squared_sums / class_counts - means ** 2)

    print 'bbox target means:'
    print means
    print means[1:, :].mean(axis=0) # ignore bg class
    print 'bbox target stdevs:'
    print stds
    print stds[1:, :].mean(axis=0) # ignore bg class

    # Normalize targets
    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS:
        print "Normalizing targets"
        for im_i in xrange(num_images):
            targets = roidb[im_i]['bbox_targets']
            for cls in xrange(1, num_classes):
                cls_inds = np.where(targets[:, 0] == cls)[0]
                roidb[im_i]['bbox_targets'][cls_inds, 1:] -= means[cls, :]
                roidb[im_i]['bbox_targets'][cls_inds, 1:] /= stds[cls, :]
    else:
        print "NOT normalizing targets"

    # These values will be needed for making predictions
    # (the predicts will need to be unnormalized and uncentered)
    return means.ravel(), stds.ravel()
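
The helper _compute_targets (not shown above) ends up calling bbox_transform to turn each (RoI, matched gt box) pair into the regression deltas [dx, dy, dw, dh]. A sketch of that standard parameterization (illustrative only; the real code lives in lib/fast_rcnn/bbox_transform.py):

import numpy as np

def bbox_deltas(ex_rois, gt_rois):
    """[dx, dy, dw, dh] targets mapping ex_rois onto gt_rois (both N x 4 arrays)."""
    ex_w = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
    ex_h = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
    ex_cx = ex_rois[:, 0] + 0.5 * ex_w
    ex_cy = ex_rois[:, 1] + 0.5 * ex_h

    gt_w = gt_rois[:, 2] - gt_rois[:, 0] + 1.0
    gt_h = gt_rois[:, 3] - gt_rois[:, 1] + 1.0
    gt_cx = gt_rois[:, 0] + 0.5 * gt_w
    gt_cy = gt_rois[:, 1] + 0.5 * gt_h

    dx = (gt_cx - ex_cx) / ex_w
    dy = (gt_cy - ex_cy) / ex_h
    dw = np.log(gt_w / ex_w)
    dh = np.log(gt_h / ex_h)
    return np.vstack((dx, dy, dw, dh)).transpose()

Since the RoIs here are the gt boxes themselves, every delta is zero; presumably that is also why the end2end config prefers the precomputed means/stds over the empirical ones.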

10、Continuing with ./lib/fast_rcnn/train.py
After sw = SolverWrapper(sess, saver, network, imdb, roidb, output_dir, pretrained_model=pretrained_model) has been built, sw.train_model(sess, max_iters) is executed. The very first line of that function is:

data_layer = get_data_layer(self.roidb, self.imdb.num_classes)

11、./lib/roi_data_layer/layer.py
This is the layer that shuffles the data order and then hands out one minibatch at a time for training:

class RoIDataLayer(object):
    """Fast R-CNN data layer used for training."""

    def __init__(self, roidb, num_classes):
        """Set the roidb to be used by this layer during training."""
        self._roidb = roidb
        self._num_classes = num_classes
        self._shuffle_roidb_inds()
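
The shuffling and batching boil down to a random permutation plus a moving cursor. Roughly, paraphrasing the two helper methods (assume numpy is imported as np, and ims_per_batch stands in for cfg.TRAIN.IMS_PER_BATCH):

    def _shuffle_roidb_inds(self):
        """Randomly permute the training roidb and reset the cursor."""
        self._perm = np.random.permutation(np.arange(len(self._roidb)))
        self._cur = 0

    def _get_next_minibatch_inds(self, ims_per_batch=1):
        """Indices for the next minibatch, reshuffling once an epoch is exhausted."""
        if self._cur + ims_per_batch >= len(self._roidb):
            self._shuffle_roidb_inds()
        db_inds = self._perm[self._cur:self._cur + ims_per_batch]
        self._cur += ims_per_batch
        return db_inds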

12、Finally, feed_dict pushes the data into the network. The key code is:

gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)
gt_boxes[:, 0:4] = roidb[0]['boxes'][gt_inds, :] * im_scales[0]
gt_boxes[:, 4] = roidb[0]['gt_classes'][gt_inds]
blobs['gt_boxes'] = gt_boxes
self.gt_boxes = tf.placeholder(tf.float32, shape=[None, 5])
feed_dict={self.net.data: blobs['data'], self.net.im_info: blobs['im_info'], self.net.keep_prob: 0.5,  self.net.gt_boxes: blobs['gt_boxes']}
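
The im_scales[0] factor comes from resizing the image so that its shorter side reaches a target scale while the longer side stays under a maximum; the gt boxes are multiplied by the same factor so they stay aligned with the resized image. A sketch of that scale computation (600 and 1000 are the usual cfg.TRAIN.SCALES / cfg.TRAIN.MAX_SIZE defaults, assumed here):

def compute_im_scale(height, width, target_size=600, max_size=1000):
    """Scale factor that brings the short side to target_size, capped by max_size."""
    im_size_min = min(height, width)
    im_size_max = max(height, width)
    im_scale = float(target_size) / im_size_min
    if round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / im_size_max
    return im_scale

print(compute_im_scale(375, 500))  # -> 1.6 for a typical VOC image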

The last few sections got a bit hurried; I was too lazy to type everything out line by line. Time for a break.
