tf-faster-rcnn代码阅读:datasets


class pascal_voc(imdb):
  """PASCAL VOC dataset description built on top of the ``imdb`` base class."""

  def __init__(self, image_set, year, use_diff=False):
    """Configure paths, the class table and the image index for one split.

    Args:
      image_set: split name; concatenated into the dataset name and used to
        pick ``ImageSets/Main/<image_set>.txt``.
      year: dataset year as a string; concatenated into names and paths.
      use_diff: stored in ``self.config['use_diff']``; when true the dataset
        name additionally gets a ``_diff`` suffix.

    Raises:
      AssertionError: if the devkit directory or the ``VOC<year>`` directory
        inside it does not exist.
    """
    dataset_name = 'voc_' + year + '_' + image_set + ('_diff' if use_diff else '')
    imdb.__init__(self, dataset_name)

    self._year = year
    self._image_set = image_set

    # Directory that holds the VOC devkit
    # (on the original author's machine: .../tf-faster-rcnn/data/VOCdevkit2007).
    self._devkit_path = self._get_default_path()
    # Year-specific folder inside the devkit (e.g. .../VOCdevkit2007/VOC2007).
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)

    self._classes = (
      '__background__',  # always index 0
      'aeroplane', 'bicycle', 'bird', 'boat',
      'bottle', 'bus', 'car', 'cat', 'chair',
      'cow', 'diningtable', 'dog', 'horse',
      'motorbike', 'person', 'pottedplant',
      'sheep', 'sofa', 'train', 'tvmonitor',
    )

    # Map every class name to its position in the tuple above, e.g.
    # {'__background__': 0, 'aeroplane': 1, ...}.  ``self.classes`` and
    # ``self.num_classes`` are provided by the imdb parent class; per the
    # original author's note num_classes is len(self._classes).
    self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))

    self._image_ext = '.jpg'
    # Identifiers of all images listed for this split.
    self._image_index = self._load_image_set_index()
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {
      'cleanup': True,
      'use_salt': True,
      'use_diff': use_diff,
      'matlab_eval': False,
      'rpn_file': None,
    }

    # Fail early when the expected directory layout is missing.
    devkit_missing_msg = 'VOCdevkit path does not exist: {}'.format(self._devkit_path)
    assert os.path.exists(self._devkit_path), devkit_missing_msg
    data_missing_msg = 'Path does not exist: {}'.format(self._data_path)
    assert os.path.exists(self._data_path), data_missing_msg

1.得到存放图片名称的txt文件


  def _load_image_set_index(self):
    """
    Read the image identifiers that belong to this dataset's split.

    The identifiers are taken from
    ``<data_path>/ImageSets/Main/<image_set>.txt``
    (for example ``.../VOC2007/ImageSets/Main/trainval.txt``).

    Returns:
      list of str: one entry per line of the file, with surrounding
      whitespace stripped.

    Raises:
      AssertionError: if the image set file does not exist.
    """
    split_list_path = os.path.join(
      self._data_path, 'ImageSets', 'Main', self._image_set + '.txt')
    assert os.path.exists(split_list_path), \
      'Path does not exist: {}'.format(split_list_path)
    with open(split_list_path) as handle:
      # Each line holds one image name (which doubles as its identifier).
      return [line.strip() for line in handle]

2.得到ground-truth的框

  def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    The result is cached in ``<cache_path>/<name>_gt_roidb.pkl`` (for example
    ``.../data/cache/voc_2007_trainval_gt_roidb.pkl``).  When that file
    exists it is unpickled and returned as is; otherwise the annotation of
    every image in ``self.image_index`` is parsed with
    ``self._load_pascal_annotation`` and the resulting list is pickled to the
    cache file, so the XML files only have to be read once.

    Returns:
      The cached object if a cache file exists, else a list with one
      annotation record per image.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')

    if os.path.exists(cache_file):
      # NOTE(security): unpickling can execute arbitrary code; only load
      # cache files that were produced by this project.
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except Exception:
          # Presumably the cache was written by Python 2 and contains 8-bit
          # strings.  The failed attempt may already have consumed part of
          # the stream, so rewind before retrying with byte strings.
          fid.seek(0)
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    # No cache yet: read the box coordinates from each image's XML file.
    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb
def _load_pascal_annotation(self, index):
    """
    Load bounding boxes info for one image from its XML file in the PASCAL
    VOC format.

    NOTE(review): this ``def`` sits at column 0 while the sibling methods are
    indented; it takes ``self`` and presumably belongs inside the class --
    confirm and re-indent when assembling a runnable file.

    Args:
      index: image identifier; the annotation is read from
        ``<data_path>/Annotations/<index>.xml``.

    Returns:
      dict with the keys
        'boxes': (num_objs, 4) uint16 array of 0-based [x1, y1, x2, y2],
        'gt_classes': (num_objs,) int32 array of class indices,
        'gt_overlaps': scipy CSR matrix of shape (num_objs, num_classes)
          holding 1.0 at each object's class column,
        'flipped': always False,
        'seg_areas': (num_objs,) float32 array of box areas.

    Raises:
      KeyError: if an object's name is not in ``self._class_to_ind``.
    """
    # Annotations are processed one image at a time.
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
      # Exclude the samples labeled as difficult
      objs = [obj for obj in objs if int(obj.find('difficult').text) == 0]
    num_objs = len(objs)

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):
      bbox = obj.find('bndbox')
      # Make pixel indexes 0-based.  Clamp at 0: a coordinate of 0 in the XML
      # would otherwise become -1 and wrap around in the uint16 box array.
      x1, y1, x2, y2 = (
        max(float(bbox.find(tag).text) - 1, 0.0)
        for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
      # Class name (e.g. 'chair') -> its index in the class table.
      cls = self._class_to_ind[obj.find('name').text.lower().strip()]

      boxes[ix, :] = [x1, y1, x2, y2]
      gt_classes[ix] = cls
      overlaps[ix, cls] = 1.0
      seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}

你可能感兴趣的:(深度学习)