简介和说明:
1.本文包含两个部分,一个是关于ROIdatabase的函数分析,一个是关于minibatch生成的函数分析;
2.由于直接用文字表述函数之间的调用关系很麻烦,因此我使用标题的等级来表示。标题等级高一级的函数调用下面标题等级低一级的函数,因此目录可以一定程度上体现出函数调用关系
3.为了突出主要部分,文章中列出的函数的有些部分做了些简化和省略,有些地方的函数调用也直接简化为function(...)
的形式;
4.本文分析的代码基于tensorflow,代码可从github上下载,网址为https://github.com/endernewton/tf-faster-rcnn
def parse_args(): #解析参数
def combined_roidb(imdb_names):
if __name__ == '__main__':
def combined_roidb(imdb_names):
def get_roidb(imdb_name): #内部函数
imdb = get_imdb(imdb_name) (1)/lib/tools/factory.py
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD) (2)/lib/dataset/imdb.py
roidb = get_training_roidb(imdb) (3)/lib/model/train_val.py
return roidb
roidbs = [get_roidb(s) for s in imdb_names.split('+')] #the function above
roidb = roidbs[0]
if len(roidbs) > 1:
for r in roidbs[1:]:
roidb.extend(r)
tmp = get_imdb(imdb_names.split('+')[1])
imdb = datasets.imdb.imdb(imdb_names, tmp.classes)
else:
imdb = get_imdb(imdb_names)
return imdb, roidb
通过给数据集的名字返回该数据集对应的类的对象
由于cfg.TRAIN.PROPOSAL_METHOD = ‘tf’,内部相当于执行了一次gt_roidb函数:
def gt_roidb(): #位于/lib/tools/pascal_voc.py
...
gt_roidb = [self._load_pascal_annotation(index)
for index in self.image_index]
#关键语句,调用同一文件下面的_load_pascal_annotation函数
#该函数从XML文件中加载图片和bbox
...
#最后返回一个字典,包含“boxes”,“gt_classes”等
def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
imdb.append_flipped_images() #通过翻转增加样本数量,位于/lib/dataset/imdb.py
rdl_roidb.prepare_roidb(imdb) (1)/lib/roi_data_layer/roidb.py
return imdb.roidb
def prepare_roidb(imdb): "为roidb加了一些说明性的属性"
for i in range(len(imdb.image_index)):
roidb[i]['image'/'width'/'height'/'max_classes'/'max_overlaps'...]
output:
roidb[img_index]包含的key | value |
---|---|
boxes | box位置信息,box_num*4的np array |
gt_overlaps | 所有box在不同类别的得分,box_num*class_num矩阵 |
gt_classes | 所有box的真实类别,box_num长度的list |
flipped | 是否翻转 |
image | 该图片的路径,字符串 |
width | 图片的宽 |
height | 图片的高 |
max_overlaps | 每个box的在所有类别的得分最大值,box_num长度 |
max_classes | 每个box的得分最高所对应的类,box_num长度 |
bbox_targets | 每个box的类别,以及与最接近的gt-box的4个方位偏移 |
if name == 'main':
args = parse_args() #the function above
...
combined_roidb(...) #the function above
...
#build network
net = vgg16/resnetv1(num_layers=50/101/152)/mobilenetv1 (2.1)/lib/nets/network.py
train_net(...) (2.2)/lib/model/train_val.py
主函数首先调用之前定义的parse_args()、combined_roidb(),每个数据集返回了带有该数据集信息的imdb和每张影像的roidb
接下来(2.1)步构建网络
以vgg16为例
class vgg16(Network): #继承与Network的子类
...
def _build_network(): #核心函数
...
net = slim.repeat/max_pool2d #构建CNN网络
...
# build the anchors for the image
self._anchor_component()
# 构建RPN
rois = self._region_proposal(net, is_training, initializer)
# 构建ROI-Pooling
if cfg.POOLING_MODE == 'crop':
pool5 = self._crop_pool_layer(net, rois, "pool5")
else:
raise NotImplementedError
...
def train_net(...): #训练网络
filter_roidb(...)
#对roi进行筛选,去掉没有用的,筛选标准为:
# Valid images have:
# (1) At least one foreground RoI OR
# (2) At least one background RoI
...
sw = SolverWrapper(...) #构造一个SolverWrapper类的对象,用于训练
sw.train_model(sess, max_iters) #
def train_model(self, sess, max_iters):
# 为训练和验证构造RoIDataLayer的对象
self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes) #/lib/roi_data_layer/layer.py
self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)
...
while iter < max_iters + 1:
# 训练数据的时候用随机梯度下降,一次获取一个batch的数据,获取方法调用forward函数
blobs = self.data_layer.forward()
#forward函数位于 /lib/roi_data_layer/layer.py ,内部仅仅调用_get_next_minibatch函数
def _shuffle_roidb_inds(self): #洗牌函数,打乱database顺序
def _get_next_minibatch_inds(self): #如果if条件满足,用shuffle函数打乱顺序并选出新一组batch的index并返回
if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):
self._shuffle_roidb_inds()
def _get_next_minibatch(self):
_get_next_minibatch_inds(...) #得到新一组batch的index
get_minibatch(...) #调用中间级函数,根据上面得到的index读出图像,
def forward(self):
blobs = self._get_next_minibatch() #顶层函数
def get_minibatch(roidb, num_classes):
“根据提供的roidb,调用_get_image_blob读取图像数据,并随机选择构造出一个minibatch样本,被layer.py里的函数调用”
im_blob, im_scales = _get_image_blob(roidb, random_scale_inds)
return blobs
def _get_image_blob(roidb, scale_inds):
for i in range(num_images):
im = cv2.imread(roidb[i]['image'])
prep_im_for_blob(...) #调用最底层文件里的函数
...
blob = im_list_to_blob(processed_ims) #调用最底层文件里的函数
这个文件里的函数主要用于将图像构造成方便训练的blob类型数据结构
def im_list_to_blob(ims):
#输入一个ims,将其转化为4维array形式的blob
def prep_im_for_blob(im, pixel_means, target_size, max_size):
#求取图像的缩放比例,然后将图像resize