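A Detailed Walkthrough of the Faster R-CNN TensorFlow Code

Background

This post analyzes the tf-faster-rcnn code in depth, based on experience running and debugging the Faster R-CNN algorithm.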

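Code analysis by component

1 Compiling the Cython modules

cd tf-faster-rcnn/lib  # first enter the directory tf-faster-rcnn/lib
make clean
make  # compile

After a successful build, .so files appear under tf-faster-rcnn/lib/nms, tf-faster-rcnn/lib/roi_pooling_layer/ and tf-faster-rcnn/lib/utils.
Note: the .so files are not portable. They are built for the machine they were compiled on, so reusing them on another machine makes the program fail. You must also delete the old .so files first with make clean; otherwise the old files are picked up and no new ones are generated. Before re-running the program in that situation, delete these .so files and recompile.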
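A quick way to see which compiled files a fresh build would have to replace is to list every .so file under lib/. The sketch below is only an illustration: the helper name find_compiled_extensions and the file names in the throwaway directory are hypothetical, and on a real checkout you would pass the path to tf-faster-rcnn/lib instead.

```python
import os
import tempfile


def find_compiled_extensions(lib_dir):
    """Return the sorted paths of all .so files below lib_dir."""
    found = []
    for root, _dirs, files in os.walk(lib_dir):
        for name in files:
            if name.endswith('.so'):
                found.append(os.path.join(root, name))
    return sorted(found)


# Throwaway directory that mimics lib/nms and lib/utils (file names are made up).
demo_lib = tempfile.mkdtemp()
for sub, fname in [('nms', 'cpu_nms.so'),
                   ('utils', 'cython_bbox.so'),
                   ('utils', 'bbox.pyx')]:
    os.makedirs(os.path.join(demo_lib, sub), exist_ok=True)
    open(os.path.join(demo_lib, sub, fname), 'w').close()

stale = find_compiled_extensions(demo_lib)
print([os.path.relpath(p, demo_lib) for p in stale])
```

If the list is non-empty after copying the project to another machine, run make clean and make again before starting the program.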

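2 Data read/write interface for the pascal_voc dataset

2.1 In the tf-faster-rcnn project, all data-reading interfaces live in tf-faster-rcnn/lib/datasets. Two datasets can be used to train the network, pascal_voc and coco, and their interfaces are pascal_voc.py and coco.py in tf-faster-rcnn/lib/datasets.

The project mainly uses the XML files in Annotations, the images in JPEGImages, and the txt files in ImageSets/Main (the directory that _load_image_set_index reads).
Other files in lib/datasets:
factory.py: a factory that creates imdb objects by name and returns the dataset used for training and testing the network;
imdb.py: the base class of the dataset read/write classes. It encapsulates many common database operations; dataset-specific file reading is implemented in subclasses.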

VOCdevkit/
VOCdevkit/VOC2007/
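VOCdevkit/VOC2007/Annotations # XML files for all images, one XML file per image; each ground-truth (gt) box is given as its top-left and bottom-right corner coordinates
VOCdevkit/VOC2007/ImageSets/
VOCdevkit/VOC2007/ImageSets/Layout # contains three txt files, train.txt, trainval.txt and val.txt, listing the image names of the training set, the training+validation set and the validation set (names have no .jpg suffix)
VOCdevkit/VOC2007/ImageSets/Main # also contains train.txt, trainval.txt and val.txt; this is the directory pascal_voc.py actually reads
VOCdevkit/VOC2007/ImageSets/Segmentation
VOCdevkit/VOC2007/JPEGImages # all images, *.jpg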
VOCdevkit/VOC2007/SegmentationClass #segmentations by class
VOCdevkit/VOC2007/SegmentationObject #segmentations by object

2.2 The pascal_voc read/write interface

  • The main entry point, if __name__ == '__main__', is at the bottom of pascal_voc.py
if __name__ == '__main__':
    from datasets.pascal_voc import pascal_voc
    d = pascal_voc('trainval', '2007')   # pascal_voc is a class
    res = d.roidb
    from IPython import embed
    embed()
  • The pascal_voc class used by the main entry point is defined at the top of pascal_voc.py:
class pascal_voc(imdb):
  def __init__(self, image_set, year, use_diff=False):
    name = 'voc_' + year + '_' + image_set
    if use_diff:
      name += '_diff'
    imdb.__init__(self, name)
    self._year = year
    self._image_set = image_set
    self._devkit_path = self._get_default_path()
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
    self._classes = ('__background__',  # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor')
    self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
    self._image_ext = '.jpg'
    self._image_index = self._load_image_set_index()
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup': True,
                   'use_salt': True,
                   'use_diff': use_diff,
                   'matlab_eval': False,
                   'rpn_file': None}

    assert os.path.exists(self._devkit_path), \
      'VOCdevkit path does not exist: {}'.format(self._devkit_path)
    assert os.path.exists(self._data_path), \
      'Path does not exist: {}'.format(self._data_path)

  def image_path_at(self, i):
    """
    Return the absolute path to image i in the image sequence.
    """
    return self.image_path_from_index(self._image_index[i])

  def image_path_from_index(self, index):
    """
    Construct an image path from the image's "index" identifier.
    """
    image_path = os.path.join(self._data_path, 'JPEGImages',
                              index + self._image_ext)
    assert os.path.exists(image_path), \
      'Path does not exist: {}'.format(image_path)
    return image_path

  def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    """
    # Example path to image set file:
    # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    assert os.path.exists(image_set_file), \
      'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:
      image_index = [x.strip() for x in f.readlines()]
    return image_index

  def _get_default_path(self):
    """
    Return the default path where PASCAL VOC is expected to be installed.
    """
    return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

  def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except:
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb

  def rpn_roidb(self):
    if int(self._year) == 2007 or self._image_set != 'test':
      gt_roidb = self.gt_roidb()
      rpn_roidb = self._load_rpn_roidb(gt_roidb)
      roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
    else:
      roidb = self._load_rpn_roidb(None)

    return roidb

  def _load_rpn_roidb(self, gt_roidb):
    filename = self.config['rpn_file']
    print('loading {}'.format(filename))
    assert os.path.exists(filename), \
      'rpn data not found at: {}'.format(filename)
    with open(filename, 'rb') as f:
      box_list = pickle.load(f)
    return self.create_roidb_from_box_list(box_list, gt_roidb)

  def _load_pascal_annotation(self, index):
    """
    Load image and bounding boxes info from XML file in the PASCAL VOC
    format.
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
      # Exclude the samples labeled as difficult
      non_diff_objs = [
        obj for obj in objs if int(obj.find('difficult').text) == 0]
      # if len(non_diff_objs) != len(objs):
      #     print 'Removed {} difficult objects'.format(
      #         len(objs) - len(non_diff_objs))
      objs = non_diff_objs
    num_objs = len(objs)

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):
      bbox = obj.find('bndbox')
      # Make pixel indexes 0-based
      x1 = float(bbox.find('xmin').text) - 1
      y1 = float(bbox.find('ymin').text) - 1
      x2 = float(bbox.find('xmax').text) - 1
      y2 = float(bbox.find('ymax').text) - 1
      cls = self._class_to_ind[obj.find('name').text.lower().strip()]
      boxes[ix, :] = [x1, y1, x2, y2]
      gt_classes[ix] = cls
      overlaps[ix, cls] = 1.0
      seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}

  def _get_comp_id(self):
    comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
               else self._comp_id)
    return comp_id

  def _get_voc_results_file_template(self):
    # VOCdevkit/results/VOC2007/Main/_det_test_aeroplane.txt
    filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
    path = os.path.join(
      self._devkit_path,
      'results',
      'VOC' + self._year,
      'Main',
      filename)
    return path

  def _write_voc_results_file(self, all_boxes):
    for cls_ind, cls in enumerate(self.classes):
      if cls == '__background__':
        continue
      print('Writing {} VOC results file'.format(cls))
      filename = self._get_voc_results_file_template().format(cls)
      with open(filename, 'wt') as f:
        for im_ind, index in enumerate(self.image_index):
          dets = all_boxes[cls_ind][im_ind]
          if dets == []:
            continue
          # the VOCdevkit expects 1-based indices
          for k in range(dets.shape[0]):
            f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                    format(index, dets[k, -1],
                           dets[k, 0] + 1, dets[k, 1] + 1,
                           dets[k, 2] + 1, dets[k, 3] + 1))

  def _do_python_eval(self, output_dir='output'):
    annopath = os.path.join(
      self._devkit_path,
      'VOC' + self._year,
      'Annotations',
      '{:s}.xml')
    imagesetfile = os.path.join(
      self._devkit_path,
      'VOC' + self._year,
      'ImageSets',
      'Main',
      self._image_set + '.txt')
    cachedir = os.path.join(self._devkit_path, 'annotations_cache')
    aps = []
    # The PASCAL VOC metric changed in 2010
    use_07_metric = True if int(self._year) < 2010 else False
    print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
    if not os.path.isdir(output_dir):
      os.mkdir(output_dir)
    for i, cls in enumerate(self._classes):
      if cls == '__background__':
        continue
      filename = self._get_voc_results_file_template().format(cls)
      rec, prec, ap = voc_eval(
        filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
        use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
      aps += [ap]
      print(('AP for {} = {:.4f}'.format(cls, ap)))
      with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
        pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    print(('Mean AP = {:.4f}'.format(np.mean(aps))))
    print('~~~~~~~~')
    print('Results:')
    for ap in aps:
      print(('{:.3f}'.format(ap)))
    print(('{:.3f}'.format(np.mean(aps))))
    print('~~~~~~~~')
    print('')
    print('--------------------------------------------------------------')
    print('Results computed with the **unofficial** Python eval code.')
    print('Results should be very close to the official MATLAB eval code.')
    print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
    print('-- Thanks, The Management')
    print('--------------------------------------------------------------')

  def _do_matlab_eval(self, output_dir='output'):
    print('-----------------------------------------------------')
    print('Computing results with the official MATLAB eval code.')
    print('-----------------------------------------------------')
    path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                        'VOCdevkit-matlab-wrapper')
    cmd = 'cd {} && '.format(path)
    cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
    cmd += '-r "dbstop if error; '
    cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
      .format(self._devkit_path, self._get_comp_id(),
              self._image_set, output_dir)
    print(('Running:\n{}'.format(cmd)))
    status = subprocess.call(cmd, shell=True)

  def evaluate_detections(self, all_boxes, output_dir):
    self._write_voc_results_file(all_boxes)
    self._do_python_eval(output_dir)
    if self.config['matlab_eval']:
      self._do_matlab_eval(output_dir)
    if self.config['cleanup']:
      for cls in self._classes:
        if cls == '__background__':
          continue
        filename = self._get_voc_results_file_template().format(cls)
        os.remove(filename)

  def competition_mode(self, on):
    if on:
      self.config['use_salt'] = False
      self.config['cleanup'] = False
    else:
      self.config['use_salt'] = True
      self.config['cleanup'] = True
  • __init__ is the initializer; it sets up the access layout of the pascal_voc dataset
  def __init__(self, image_set, year, use_diff=False):
    name = 'voc_' + year + '_' + image_set
    if use_diff:
      name += '_diff'
    imdb.__init__(self, name)  # calls the initializer __init__() of the base class imdb with a name such as voc_2007_train. The imdb class is defined in lib/datasets/imdb.py
    self._year = year # a str, the year of the VOC data, '2007' or '2012'; 2007 is the example used here
    self._image_set = image_set # a str, one of 'train', 'test', 'trainval' or 'val', selecting the training, test, training+validation or validation split; train is the example used here
    self._devkit_path = self._get_default_path() # calls _get_default_path(self); the path is data/VOCdevkit2007
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year) # data/VOCdevkit2007/VOC2007
    self._classes = ('__background__',  # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',  
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor') 
    # all object classes contained in the dataset
    self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
    # builds the dict {'__background__': 0, 'aeroplane': 1, 'bicycle': 2, 'bird': 3, 'boat': 4, 'bottle': 5, 'bus': 6, 'car': 7, 'cat': 8, 'chair': 9, 'cow': 10, 'diningtable': 11, 'dog': 12, 'horse': 13, 'motorbike': 14, 'person': 15, 'pottedplant': 16, 'sheep': 17, 'sofa': 18, 'train': 19, 'tvmonitor': 20}; the values are integer indices.
    # self.num_classes is the total number of classes, 21 (background counts as one); this property is inherited from lib/datasets/imdb.py
    self._image_ext = '.jpg'  # image file extension
    self._image_index = self._load_image_set_index()  # loads the list of sample names
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb   # with an RPN, read and return the ground-truth database. gt_roidb does not extract ROIs itself, because Faster R-CNN uses the RPN for that; it only returns the gt of each image. (Fast R-CNN has no RPN.)
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup': True,
                   'use_salt': True,
                   'use_diff': use_diff,
                   'matlab_eval': False,
                   'rpn_file': None}

    assert os.path.exists(self._devkit_path), \
      'VOCdevkit path does not exist: {}'.format(self._devkit_path) # the assertion fails if self._devkit_path (the VOCdevkit2007 directory) does not exist
    assert os.path.exists(self._data_path), \
      'Path does not exist: {}'.format(self._data_path) # the assertion fails if self._data_path (VOCdevkit2007/VOC2007) does not exist
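To make the two derived values in __init__ concrete, here is a minimal standalone sketch of how the dataset name and the class-to-index mapping come out. It uses a shortened class tuple for illustration only; the real class list has 21 entries.

```python
# Shortened class list, for illustration only; index 0 is always the background.
classes = ('__background__', 'aeroplane', 'bicycle', 'bird')

# Same construction as in __init__: zip each class name with its position.
class_to_ind = dict(zip(classes, range(len(classes))))

# The dataset name is assembled from year and image set.
year, image_set = '2007', 'train'
name = 'voc_' + year + '_' + image_set

print(name)
print(class_to_ind)
```

The mapping values are integers, which is what allows them to be used directly as column indices into the overlaps array in _load_pascal_annotation.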
  • Helper method def _get_default_path(self)
  def _get_default_path(self):
    """
    Return the default path where PASCAL VOC is expected to be installed.
    Default path of the pascal_voc dataset: tf-faster-rcnn/data/VOCdevkit2007
    """
    return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year) # cfg.DATA_DIR is defined in tf-faster-rcnn/lib/model/config.py

DATA_DIR is defined in tf-faster-rcnn/lib/model/config.py as follows (lines 257-261):

# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))

# Data directory
__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))
  • Helper method def _load_image_set_index(self)
  def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    """
    # Example path to image set file:
    # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    # image_set_file is tf-faster-rcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/train.txt
    # train.txt is read because it lists the names of all images in the train set (without the .jpg suffix)
    assert os.path.exists(image_set_file), \
      'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:  # read the train.txt file above
      image_index = [x.strip() for x in f.readlines()]  # store the contents of train.txt (the image names) in image_index
    return image_index  # names of all images in image_set (without the .jpg suffix)

The result is a list containing the names of all images in the set self._image_set (note that the names have no .jpg suffix).
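The following sketch reproduces the index-loading step and the later path construction on a temporary directory, so it can be run without the VOC data. The two free functions are simplified stand-ins for the class methods, and the image names written to train.txt are made up; unlike the original, this version also skips blank lines.

```python
import os
import tempfile


def load_image_set_index(image_set_file):
    """Read one image name per line, stripping whitespace and blank lines."""
    with open(image_set_file) as f:
        return [line.strip() for line in f if line.strip()]


def image_path_from_index(data_path, index, image_ext='.jpg'):
    """Build <data_path>/JPEGImages/<index>.jpg from an extension-free name."""
    return os.path.join(data_path, 'JPEGImages', index + image_ext)


# Temporary stand-in for data/VOCdevkit2007/VOC2007.
data_path = tempfile.mkdtemp()
main_dir = os.path.join(data_path, 'ImageSets', 'Main')
os.makedirs(main_dir)
with open(os.path.join(main_dir, 'train.txt'), 'w') as f:
    f.write('000005\n000007\n')

image_index = load_image_set_index(os.path.join(main_dir, 'train.txt'))
first_path = image_path_from_index(data_path, image_index[0])
print(image_index, first_path)
```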

  • Helper method def gt_roidb(self)
  def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    # name of the .pkl cache file. self.cache_path and self.name are inherited from the imdb class, defined in lib/datasets/imdb.py
    if os.path.exists(cache_file):  # the .pkl file exists, so this function has run before and cached the ground truth
      with open(cache_file, 'rb') as fid:  # open it
        try:
          roidb = pickle.load(fid)
        except:
          roidb = pickle.load(fid, encoding='bytes')  # load the cached data
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb  # return it

    gt_roidb = [self._load_pascal_annotation(index)    # the .pkl file does not exist, so this is the first run of this function
                for index in self.image_index]  # first collect the gt of every image; _load_pascal_annotation does that
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)  # save the gt to the .pkl file
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb

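Reads and returns the ground-truth database; in other words, this function loads the gt of the images.
The gt of the pascal_voc images lives in the XML files, and it is cached in a .pkl file. (We generate this .pkl file ourselves; the code is in this function.) Caching avoids re-parsing the gt on every run: loading the file directly is faster.
self.cache_path and self.name are inherited from the imdb class, defined in tf-faster-rcnn/lib/datasets/imdb.py. The imdb class defines self.cache_path at lines 77-82 of imdb.py: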

  @property
  def name(self):
    return self._name

  @property
  def cache_path(self):
    cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))
    if not os.path.exists(cache_path):
      os.makedirs(cache_path)
    return cache_path

The imdb class defines self.name at lines 23-35 of imdb.py:

  def __init__(self, name, classes=None):  # initializer of the imdb class, called from pascal_voc.py
    self._name = name   # name is the parameter; the value passed in is 'voc_2007_train', 'voc_2007_test', 'voc_2007_val' or 'voc_2007_trainval'
    self._num_classes = 0
    if not classes:
      self._classes = []
    else:
      self._classes = classes
    self._image_index = []
    self._obj_proposer = 'gt'
    self._roidb = None
    self._roidb_handler = self.default_roidb
    # Use this dict for storing dataset specific config options
    self.config = {}

  @property
  def name(self):  # where the imdb class defines self.name
    return self._name  # returns self._name from this file, imdb.py

Note: if you change the train set between runs by adding or removing data, you must delete the cached .pkl file before retraining. According to cache_path above, it lives in data/cache. If you leave it in place, the old .pkl file is loaded automatically and no new one is generated.
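The stale-cache behaviour described in the note can be reproduced with a few lines of standard-library code. This is a sketch of the same load-or-build pattern, not the project's code: load_or_build and the placeholder annotation lists are hypothetical names used only for illustration.

```python
import os
import pickle
import tempfile


def load_or_build(cache_file, build_fn):
    """Return (data, cache_hit). Build and cache the data only if no cache exists."""
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fid:
            return pickle.load(fid), True
    data = build_fn()
    with open(cache_file, 'wb') as fid:
        pickle.dump(data, fid, pickle.HIGHEST_PROTOCOL)
    return data, False


cache_dir = tempfile.mkdtemp()  # stands in for data/cache
cache_file = os.path.join(cache_dir, 'voc_2007_train_gt_roidb.pkl')

# First call builds and writes the cache.
first, hit1 = load_or_build(cache_file, lambda: ['annotations v1'])
# The dataset changed, but the old cache is still returned.
second, hit2 = load_or_build(cache_file, lambda: ['annotations v2'])
# Only after deleting the cache file is the new data picked up.
os.remove(cache_file)
third, hit3 = load_or_build(cache_file, lambda: ['annotations v2'])

print(hit1, hit2, hit3, second, third)
```

The second call is the failure mode the note warns about: the builder for the updated data is never invoked because the file check succeeds first.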

  • Helper method def _load_pascal_annotation(self, index), which implements the actual reading of an image's gt
  def _load_pascal_annotation(self, index):
    """
    Load image and bounding boxes info from XML file in the PASCAL VOC
    format.
    The XML files hold the image and gt information of PASCAL VOC. They are
    downloaded together with the VOC dataset and live in the Annotations folder.
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    # filename is the path of one XML file, where index is an image name without suffix, e.g. VOCdevkit2007/VOC2007/Annotations/000005.xml
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
      # Exclude the samples labeled as difficult
      non_diff_objs = [
        obj for obj in objs if int(obj.find('difficult').text) == 0]
      # if len(non_diff_objs) != len(objs):
      #     print 'Removed {} difficult objects'.format(
      #         len(objs) - len(non_diff_objs))
      objs = non_diff_objs
    num_objs = len(objs)  # number of objects in the input image

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):  # for each object in the image
      bbox = obj.find('bndbox')
      # the pascal_voc XML gives each gt box as top-left and bottom-right corner coordinates
      # Make pixel indexes 0-based
      x1 = float(bbox.find('xmin').text) - 1
      y1 = float(bbox.find('ymin').text) - 1
      x2 = float(bbox.find('xmax').text) - 1
      y2 = float(bbox.find('ymax').text) - 1
      # Why subtract 1? VOC annotations use 1-based pixel coordinates; subtracting 1 makes them 0-based
      cls = self._class_to_ind[obj.find('name').text.lower().strip()]
      # look up the class index cls of this object
      boxes[ix, :] = [x1, y1, x2, y2]
      gt_classes[ix] = cls
      overlaps[ix, cls] = 1.0
      seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
      # seg_areas[ix] is the area of this object's gt box

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}
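A reduced version of the parsing logic can be tried on an inline XML string. The annotation below is made up for illustration and parse_boxes is a hypothetical helper; it keeps only the steps discussed above (filtering difficult objects, shifting to 0-based coordinates, computing the box area) and uses int() where the original uses float() and NumPy arrays.

```python
import xml.etree.ElementTree as ET

# Made-up annotation with one normal and one "difficult" object.
XML = """<annotation>
  <object>
    <name>dog</name><difficult>0</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>person</name><difficult>1</difficult>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>498</ymax></bndbox>
  </object>
</annotation>"""


def parse_boxes(xml_text, use_diff=False):
    """Return a list of (class_name, (x1, y1, x2, y2), area) with 0-based coordinates."""
    root = ET.fromstring(xml_text)
    results = []
    for obj in root.findall('object'):
        # Skip objects flagged as difficult unless use_diff is set.
        if not use_diff and int(obj.find('difficult').text) != 0:
            continue
        bbox = obj.find('bndbox')
        # VOC coordinates are 1-based, so subtract 1 from each corner.
        x1, y1, x2, y2 = (int(bbox.find(tag).text) - 1
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        # Both corners are inclusive, hence the +1 in width and height.
        area = (x2 - x1 + 1) * (y2 - y1 + 1)
        results.append((obj.find('name').text.lower().strip(),
                        (x1, y1, x2, y2), area))
    return results


parsed = parse_boxes(XML)
print(parsed)
```

Because both corners shift by the same amount, the area is unaffected by the 1-based to 0-based conversion.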
  • Helper method def image_path_at(self, i)
  def image_path_at(self, i):
    """
    Return the absolute path to image i in the image sequence.
    """
    return self.image_path_from_index(self._image_index[i])

Returns the path of the i-th image sample; it delegates the actual work to image_path_from_index(self, index).

  • Helper method def image_path_from_index(self, index)
  def image_path_from_index(self, index):
    """
    Construct an image path from the image's "index" identifier.
    """
    image_path = os.path.join(self._data_path, 'JPEGImages',
                              index + self._image_ext)
    # path of the image itself; index is an image name without suffix and _image_ext is the extension .jpg, e.g. VOCdevkit2007/VOC2007/JPEGImages/000005.jpg
    assert os.path.exists(image_path), \
      'Path does not exist: {}'.format(image_path)
    # the assertion fails if the path does not exist
    return image_path

As the code above shows, pascal_voc.py relies heavily on path concatenation.

3 Modifying the model configuration

  • Modifying config.py
    All model parameters of the tf-faster-rcnn project are defined in tf-faster-rcnn/lib/model/config.py.
# Images to use per minibatch
__C.TRAIN.IMS_PER_BATCH = 1  # one image is fed to the Faster R-CNN network per iteration

# Iterations between snapshots
__C.TRAIN.SNAPSHOT_ITERS = 5000 # during training, save the model every 5000 iterations

# solver.prototxt specifies the snapshot path prefix, this adds an optional
# infix to yield the path: [_]_iters_XYZ.caffemodel
__C.TRAIN.SNAPSHOT_PREFIX = 'res101_faster_rcnn'  # name prefix used when saving the model

# Use RPN to detect objects
__C.TRAIN.HAS_RPN = True # whether to use the RPN; True enables it
  • Analysis of demo.py
CLASSES = ('__background__',
           '"seaurchin"', '"scallop"', '"seacucumber"')
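NETS = {'vgg16': ('vgg16_faster_rcnn_iter_15000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

def vis_detections: draws the bounding boxes on a test image. Parameters: im is the test image; class_name is the class name, one of the CLASSES defined above; dets is the array of bboxes and scores after non-maximum suppression; thresh is the final score threshold, and only boxes scoring at or above it are drawn.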

def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""  
    # select the dets whose score reaches the threshold
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
    # cv2.imread returns an array of shape (height, width, channels)
    # with the channels stored in BGR order; matplotlib expects RGB, hence the conversion

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:   # take the bbox and score from dets
        bbox = dets[i, :4]
        score = dets[i, -1]
        # draw the rectangle from its top-left corner, width and height
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
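The three array operations that vis_detections depends on (score filtering, corner-to-rectangle conversion, and the BGR to RGB swap) can be checked in isolation with NumPy. The detection values and the tiny image below are made up for illustration.

```python
import numpy as np

# Made-up detections: each row is [x1, y1, x2, y2, score].
dets = np.array([[10, 20, 110, 220, 0.95],
                 [15, 25, 100, 200, 0.40],
                 [200, 50, 300, 150, 0.80]], dtype=np.float32)
thresh = 0.5

# Row indices of detections whose score reaches the threshold.
inds = np.where(dets[:, -1] >= thresh)[0]

# plt.Rectangle takes (x, y) of the top-left corner plus width and height,
# so the two corner points are converted to that form.
rects = []
for i in inds:
    x1, y1, x2, y2 = dets[i, :4]
    rects.append((float(x1), float(y1), float(x2 - x1), float(y2 - y1)))

# A 2x3 image in BGR order, filled with pure blue.
im_bgr = np.zeros((2, 3, 3), dtype=np.uint8)  # (height, width, channels)
im_bgr[..., 0] = 255
# Reordering the last axis turns BGR into RGB.
im_rgb = im_bgr[:, :, (2, 1, 0)]

print(inds, rects, im_rgb[0, 0])
```

After the swap the blue value sits in the last channel, which is where matplotlib's imshow expects it.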

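def demo: runs detection on a test image, applies non-maximum suppression per class, then calls vis_detections to draw the boxes. Parameters: sess is the TensorFlow session; net is the network used for testing; image_name is the image file name.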

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8   # score threshold: only boxes scoring >= this value are drawn
    NMS_THRESH = 0.3    # non-maximum suppression threshold, used to remove duplicate boxes
    for cls_ind, cls in enumerate(CLASSES[1:]):
        # enumerate yields the index cls_ind and the name cls of each class in CLASSES
        cls_ind += 1 # because we skipped background
        # take the bboxes and scores of this class
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        # stack bbox and score together into dets
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)  # non-maximum suppression; keep holds the indices that survive
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)  # draw the boxes
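To see what the per-class slicing and the nms call do to the arrays, here is a plain NumPy sketch. nms_numpy is a standard greedy NMS written for illustration; it is not the compiled nms from model.nms_wrapper, and the boxes and scores are made up.

```python
import numpy as np


def nms_numpy(dets, thresh):
    """Greedy NMS. dets rows are [x1, y1, x2, y2, score]; returns kept row indices."""
    x1, y1, x2, y2, scores = (dets[:, k] for k in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the top box with every remaining box.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        # Drop boxes that overlap the top box by more than thresh.
        order = rest[iou <= thresh]
    return keep


# Layout as returned by im_detect: 4 box columns per class, 1 score column per class.
num_classes = 3  # background + 2 object classes (illustrative)
boxes = np.zeros((3, 4 * num_classes), dtype=np.float32)
scores = np.zeros((3, num_classes), dtype=np.float32)
cls_ind = 1
boxes[:, 4 * cls_ind:4 * (cls_ind + 1)] = [[10, 10, 110, 110],
                                           [12, 12, 112, 112],
                                           [200, 200, 300, 300]]
scores[:, cls_ind] = [0.9, 0.8, 0.7]

cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
cls_scores = scores[:, cls_ind]
dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
keep = nms_numpy(dets, 0.3)
dets = dets[keep, :]
print(keep, dets.shape)
```

The second box overlaps the first almost completely, so it is suppressed; the third box is disjoint and survives regardless of its lower score.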
  • def parse_args: parses the command-line arguments to obtain the network, the dataset and so on.
def parse_args():
    """Parse input arguments."""
    # create the parser object
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    # call parser.parse_args to parse the arguments and return them as args
    args = parser.parse_args()

    return args
  • Main entry point
if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    # parse the command-line arguments
    args = parse_args()
    # use the CPU implementation of NMS
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 21,
                          tag='default', anchor_scales=[8, 16, 32])
    # when testing on your own dataset, change 21 to your number of classes (including background)
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    im_names = ['000000.jpg']
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

    plt.show()
  • After training a model on your own dataset, running the demo with all classes drawn on the same image requires the following modifications.
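The checkpoint location checked in the main block is assembled purely from the two command-line choices. The sketch below isolates that step; checkpoint_prefix is a hypothetical helper, while the NETS and DATASETS entries are copied from the listings in this post.

```python
import os

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_15000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',),
            'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}


def checkpoint_prefix(demonet, dataset):
    """Build output/<net>/<dataset>/default/<checkpoint name>, as the main block does."""
    return os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                        NETS[demonet][0])


tfmodel = checkpoint_prefix('vgg16', 'pascal_voc_0712')
# The main block tests for this file before restoring the model.
meta_file = tfmodel + '.meta'
print(meta_file)
```

If the iteration count in NETS does not match the snapshot you actually trained, the .meta check fails with the IOError shown above, so that entry is the first thing to adjust for a custom model.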
#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',
           '"seaurchin"', '"scallop"', '"seacucumber"')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

# add the ax parameter as the 4th argument
def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # comment out the following three lines of the original code
    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          #edgecolor='red', linewidth=3.5)
                          edgecolor='red', linewidth=1)  # line width reduced from 3.5 to 1 for a thinner red box
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    # comment out the following three lines of the original code
    #plt.axis('off')
    #plt.tight_layout()
    #plt.draw()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    # the 3 lines that preceded the for loop in vis_detections are moved here
    im = im[:, :, (2, 1, 0)]
    fig,ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        # pass ax to vis_detections as the 4th argument
        vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
    # the 3 lines that followed the for loop in vis_detections are moved here
    plt.axis('off')
    plt.tight_layout()
    plt.draw()



def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 4,
                          tag='default', anchor_scales=[8, 16, 32])
    # The 2nd argument of net.create_architecture is the number of object classes + 1;
    # this example has 3 object classes, so 4 in total including background
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    im_names = ['000337.jpg']   # test image, stored under tf-faster-rcnn-contest/data/demo
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

plt.show()

  • After training a model on your own dataset, to run demo.py in batch over the test images and show all classes on the same image, modify the code as follows.
#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

import scipy.io as sio
import sys

import numpy
from PIL import Image   # import the Image module
from pylab import *     # star import; provides savetxt among others

CLASSES = ('__background__',
           'holothurian', 'echinus', 'scallop', 'starfish')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          #edgecolor='red', linewidth=3.5)
                          edgecolor='red', linewidth=1)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    #plt.axis('off')
    #plt.tight_layout()
    #plt.draw()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    save_jpg = os.path.join('/data/test', image_name)

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    im = im[:, :, (2, 1, 0)]
    fig,ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]

        vis_detections(im, cls, dets,ax,thresh=CONF_THRESH)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def get_imlist(path):  # Return the file names of all .jpg images in the given folder
    return [f for f in os.listdir(path) if f.endswith('.jpg')]



def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST",5,
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))


    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']
    im_names = get_imlist(r"/home/ouc/LiuHongzhi/tf-faster-rcnn-contest -2018/data/demo")
    print(im_names)
    for im_name in im_names:
    #path = "/home/henry/Files/URPC2018/VOC/VOC2007/JPEGImages/G0024172/*.jpg"
    #filelist = os.listdir(path)
    #for im_name in path:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)
        plt.savefig("testfigs/" + im_name)
#plt.show()
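
The batch loop above depends on listing the .jpg files in a folder and saving one figure per image into testfigs/. The following is only a sketch of a slightly more defensive variant of those two steps, assuming Python 3; the helper names list_jpgs and output_path are hypothetical and are not part of tf-faster-rcnn.

```python
import os


def list_jpgs(folder):
    """Return the sorted file names (not full paths) of all .jpg files in folder."""
    # Lower-casing the name also accepts files ending in .JPG (an assumption,
    # the original get_imlist only matches lowercase .jpg).
    return sorted(f for f in os.listdir(folder) if f.lower().endswith('.jpg'))


def output_path(out_dir, image_name):
    """Create out_dir if it does not exist and return the path for the saved figure."""
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, image_name)
```

With these helpers the loop body could call plt.savefig(output_path('testfigs', im_name)), which avoids a failure when the testfigs directory has not been created beforehand.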
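  • After training a model on your own dataset, to run demo.py in batch over the test images and write the detections out as formatted text lines, modify the code as follows.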
#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import os.path
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

import scipy.io as sio
import sys

import numpy
from PIL import Image   # import the Image module
from pylab import *     # star import; provides savetxt among others

CLASSES = ('__background__',
           'holothurian', 'echinus', 'scallop', 'starfish')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, thresh=0.5):
    """Append detections scoring at least thresh to result.txt."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')


    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        if class_name in CLASSES:
            # Append one line per detection: image name, class, score, xmin ymin xmax ymax.
            # The final txt is saved at this path; change it here if needed.
            # im_name is the module-level loop variable set in the __main__ block.
            with open('result.txt', 'a') as fw:
                fw.write(' '.join([str(im_name), class_name, str(score),
                                   str(int(bbox[0])), str(int(bbox[1])),
                                   str(int(bbox[2])), str(int(bbox[3]))]) + '\n')

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    save_jpg = os.path.join('/data/test', image_name)

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    #im = im[:, :, (2, 1, 0)]
    #fig,ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]

        vis_detections(im, cls, dets,thresh=CONF_THRESH)

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST",5,
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))


    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']

    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in range(2):
        _, _= im_detect(sess,net, im)

    #im_names = get_imlist(r"/home/henry/Files/tf-faster-rcnn-contest/data/demo")
    fr = open('/home/ouc/LiuHongzhi/tf-faster-rcnn-contest -2018/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt', 'r')
    for im_name in fr:
    #path = "/home/henry/Files/URPC2018/VOC/VOC2007/JPEGImages/G0024172/*.jpg"
    #filelist = os.listdir(path)
    #for im_name in path:
        im_name = im_name.strip('\n')
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)
#plt.show()
    fr.close()

The output looks like this:

000646.jpg echinus 0.9531797 617 89 785 272
000646.jpg echinus 0.94367296 200 272 396 495
000646.jpg echinus 0.9090044 953 259 1112 443
000646.jpg scallop 0.8987418 1508 975 1580 1037
000646.jpg scallop 0.8006968 512 169 580 218
000646.jpg starfish 0.96790546 291 675 390 765
001834.jpg echinus 0.9706842 291 222 365 280
001834.jpg echinus 0.965007 511 161 588 229
001834.jpg echinus 0.95911396 2 184 136 283
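
Each output line holds the image name, class, score and the four box corners, separated by spaces. As a minimal sketch of reading such lines back for later evaluation, the snippet below parses one line into a named record. The names Detection and parse_result_line are hypothetical, and the sketch assumes image names contain no spaces.

```python
from collections import namedtuple

# One record per detection line written to result.txt.
Detection = namedtuple('Detection', 'image cls score xmin ymin xmax ymax')


def parse_result_line(line):
    """Parse one 'image class score xmin ymin xmax ymax' line into a Detection."""
    image, cls, score, xmin, ymin, xmax, ymax = line.split()
    return Detection(image, cls, float(score),
                     int(xmin), int(ymin), int(xmax), int(ymax))


sample = "000646.jpg echinus 0.9531797 617 89 785 272"
det = parse_result_line(sample)
# Class name, box width and box height of the parsed detection.
print(det.cls, det.xmax - det.xmin, det.ymax - det.ymin)  # echinus 168 183
```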

4 Supplementary notes

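  • argparse
    argparse is Python's standard module for parsing command-line arguments and options, and it replaces the deprecated optparse module. For example, it can parse a command such as python parseTest.py input.txt output.txt --user=name --port=8080.
    Usage steps:
    1:import argparse
    2:parser = argparse.ArgumentParser()
    3:parser.add_argument()
    4:parser.parse_args()
    Explanation: first import the module; then create a parser object; then add the command-line arguments and options you care about to that object, with one add_argument call per argument or option; finally call parse_args() to do the parsing.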
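
The four steps can be sketched on the example command from the text. The argument names input, output, --user and --port are taken from that example command; the defaults and help strings are assumptions made for illustration. The argument list is passed to parse_args explicitly so that the snippet runs without a real command line.

```python
import argparse


def build_parser():
    """Steps 2 and 3: create the parser and register the arguments."""
    parser = argparse.ArgumentParser(description='argparse usage sketch')
    parser.add_argument('input', help='input file (positional)')
    parser.add_argument('output', help='output file (positional)')
    parser.add_argument('--user', default=None, help='user name (optional)')
    parser.add_argument('--port', type=int, default=8080, help='port number (optional)')
    return parser


# Step 4: parse. Passing a list stands in for sys.argv[1:].
args = build_parser().parse_args(
    ['input.txt', 'output.txt', '--user=name', '--port=8080'])
print(args.input, args.output, args.user, args.port)  # input.txt output.txt name 8080
```

demo.py follows the same pattern in parse_args(), where dest='demo_net' makes the value of --net available as args.demo_net.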

  • IoU and non-maximum suppression
    IoU reference
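
As a rough illustration of what nms(dets, NMS_THRESH) does in demo.py, the following pure-Python sketch computes IoU for two boxes and applies greedy non-maximum suppression. It is not the repository's implementation: the function names iou and nms_sketch are hypothetical, boxes are treated as continuous (xmin, ymin, xmax, ymax) coordinates, and the real code may differ in details such as pixel-inclusive widths.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (xmin, ymin, xmax, ymax)."""
    iw = min(a[2], b[2]) - max(a[0], b[0])
    ih = min(a[3], b[3]) - max(a[1], b[1])
    if iw <= 0 or ih <= 0:
        return 0.0  # the boxes do not overlap
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / float(area_a + area_b - inter)


def nms_sketch(dets, thresh):
    """Greedy NMS over dets = [(xmin, ymin, xmax, ymax, score), ...].

    Returns the indices of the kept boxes, highest score first.
    """
    order = sorted(range(len(dets)), key=lambda i: dets[i][4], reverse=True)
    keep = []
    for i in order:
        # Keep box i only if it overlaps no already-kept box by more than thresh.
        if all(iou(dets[i][:4], dets[k][:4]) <= thresh for k in keep):
            keep.append(i)
    return keep


dets = [(0, 0, 10, 10, 0.9), (1, 1, 11, 11, 0.8), (20, 20, 30, 30, 0.7)]
print(nms_sketch(dets, 0.3))  # [0, 2]
```

The second box overlaps the first with an IoU of 81/119 (about 0.68), which is above the 0.3 threshold, so it is suppressed; the third box is disjoint from the first and is kept.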
