Data augmentation in mmdetection

Contents

  • Preface
  • LoadImageFromFile
  • LoadAnnotations
  • 1. Resize
  • 2. RandomFlip
  • 3. Normalize
  • 4. Pad
  • 5. DefaultFormatBundle
  • 6. Collect
  • The final data dict


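Preface

These are notes on the basic data augmentation steps used in mmdetection configs. All code shown here is excerpted from the library, keeping only the parts that make each step easy to follow.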
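For orientation, the steps covered in these notes correspond to a training pipeline of roughly the following shape. This is a sketch modeled on common mmdetection 2.x default configs; the image scale, mean/std values, flip ratio, and collected keys are assumptions taken from typical configs and may differ in your version or config file.

```python
# Typical ImageNet statistics used by many configs (assumed values).
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# Each dict names one transform; they run in order on the same results dict.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
```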

LoadImageFromFile

This step loads the image and stores its information in the results dict. The code is in mmdet/datasets/pipelines/loading.py.
The results dict before this step runs:
[Figure 1: the initial results dict]

results['filename'] = filename
results['ori_filename'] = results['img_info']['filename']
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
return results

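After the step has run:
[Figures: the results dict after LoadImageFromFile]
The later steps work the same way: each one adds entries to the dict or processes the data already in it, so I won't include screenshots for them.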
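To make the bookkeeping concrete, here is a minimal sketch that fills a results dict the same way for an image that is already decoded. The function name `load_image_into_results` and the sample file name are hypothetical; the real transform also reads the file from disk, which is skipped here.

```python
import os.path as osp

import numpy as np


def load_image_into_results(results, img):
    """Record an already-decoded image in results (hypothetical helper)."""
    if results.get('img_prefix') is not None:
        filename = osp.join(results['img_prefix'],
                            results['img_info']['filename'])
    else:
        filename = results['img_info']['filename']
    results['filename'] = filename
    results['ori_filename'] = results['img_info']['filename']
    results['img'] = img
    results['img_shape'] = img.shape   # updated by later transforms
    results['ori_shape'] = img.shape   # stays as the original shape
    results['img_fields'] = ['img']    # keys that hold image arrays
    return results


results = {'img_info': {'filename': 'demo.jpg'}, 'img_prefix': 'data'}
fake_img = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a decoded image
results = load_image_into_results(results, fake_img)
print(sorted(results.keys()))
```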

LoadAnnotations

This step loads the annotations for the image. The code is also in mmdet/datasets/pipelines/loading.py.

if self.with_bbox:
    results = self._load_bboxes(results)
    if results is None:
        return None
if self.with_label:
    results = self._load_labels(results)
if self.with_mask:
    results = self._load_masks(results)
if self.with_seg:
    results = self._load_semantic_seg(results)
return results
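As a rough illustration of what the box and label loaders do, the sketch below copies annotations from an `ann_info` entry into the results dict. It assumes `ann_info` holds `bboxes` and `labels` arrays and that box keys are tracked in `bbox_fields`; the function name is hypothetical and handling of ignored boxes, masks, and segmentation maps is left out.

```python
import numpy as np


def load_bboxes_and_labels(results):
    """Copy boxes and labels from ann_info into results (simplified sketch)."""
    ann_info = results['ann_info']
    results['gt_bboxes'] = ann_info['bboxes'].copy()
    # Later transforms iterate over bbox_fields to find every box array.
    results.setdefault('bbox_fields', []).append('gt_bboxes')
    results['gt_labels'] = ann_info['labels'].copy()
    return results


results = {
    'ann_info': {
        'bboxes': np.array([[10., 20., 110., 220.]], dtype=np.float32),
        'labels': np.array([3], dtype=np.int64),
    }
}
results = load_bboxes_and_labels(results)
```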

1. Resize

Most of the data augmentation code lives in mmdet/datasets/pipelines/transforms.py. If you need to add your own augmentation, that file is the best place for it.

self._resize_img(results)
self._resize_bboxes(results)    # scale the boxes by the scale factor, then clip them to the image bounds
self._resize_masks(results)
self._resize_seg(results)
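The box step (scale, then clip) can be sketched with plain NumPy. This is a minimal version assuming boxes are stored as (x1, y1, x2, y2) rows and the scale factor is a 4-element array (w_scale, h_scale, w_scale, h_scale); the standalone function name is hypothetical.

```python
import numpy as np


def resize_bboxes(bboxes, scale_factor, img_shape):
    """Scale (N, 4) boxes and clip them to the resized image (h, w, ...)."""
    bboxes = bboxes * scale_factor  # new array, the input is not modified
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])  # x in [0, w]
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])  # y in [0, h]
    return bboxes


bboxes = np.array([[10., 20., 300., 250.]], dtype=np.float32)
scale_factor = np.array([2., 2., 2., 2.], dtype=np.float32)
resized = resize_bboxes(bboxes, scale_factor, img_shape=(400, 500, 3))
print(resized)
```

In this example the scaled box would reach x2 = 600 and y2 = 500, so both coordinates are clipped to the 500 x 400 image.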

2. RandomFlip

cur_dir = np.random.choice(direction_list, p=flip_ratio_list)   # randomly pick a flip direction, or None for no flip
results['flip'] = cur_dir is not None       # True when a direction was picked
if 'flip_direction' not in results:
   results['flip_direction'] = cur_dir
if results['flip']:
   # flip image
   for key in results.get('img_fields', ['img']):
       results[key] = mmcv.imflip(
           results[key], direction=results['flip_direction'])
   # flip bboxes
   for key in results.get('bbox_fields', []):
       results[key] = self.bbox_flip(results[key],
                                     results['img_shape'],
                                     results['flip_direction'])
   # flip masks
   for key in results.get('mask_fields', []):
       results[key] = results[key].flip(results['flip_direction'])

   # flip segs
   for key in results.get('seg_fields', []):
       results[key] = mmcv.imflip(
           results[key], direction=results['flip_direction'])
return results
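For the horizontal case, flipping a box amounts to mirroring its x coordinates around the image width and swapping x1 with x2. A minimal NumPy sketch, assuming (x1, y1, x2, y2) boxes and an `img_shape` of (h, w, ...); the standalone function name is hypothetical:

```python
import numpy as np


def bbox_flip_horizontal(bboxes, img_shape):
    """Mirror (..., 4k) boxes left-right inside an image of shape (h, w, ...)."""
    assert bboxes.shape[-1] % 4 == 0
    flipped = bboxes.copy()
    w = img_shape[1]
    flipped[..., 0::4] = w - bboxes[..., 2::4]  # new x1 comes from old x2
    flipped[..., 2::4] = w - bboxes[..., 0::4]  # new x2 comes from old x1
    return flipped


boxes = np.array([[10., 20., 30., 40.]], dtype=np.float32)
flipped = bbox_flip_horizontal(boxes, img_shape=(100, 200, 3))
print(flipped)
```

The y coordinates are untouched, and flipping twice returns the original box.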

3. Normalize

The code is as follows (example):

 for key in results.get('img_fields', ['img']):	# normalize every image field, by default just img
     results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
                                     self.to_rgb)
 results['img_norm_cfg'] = dict(
     mean=self.mean, std=self.std, to_rgb=self.to_rgb)
 return results
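What the normalization does per pixel can be sketched in NumPy: optionally reorder BGR to RGB, subtract the per-channel mean, and divide by the per-channel std. This is an approximation of the behavior under those assumptions, not the mmcv implementation itself; the function name is hypothetical.

```python
import numpy as np


def normalize(img, mean, std, to_rgb=True):
    """Return (img - mean) / std as float32, with optional BGR -> RGB swap."""
    img = img.astype(np.float32)
    if to_rgb:
        img = img[..., ::-1]  # reverse the channel axis: BGR -> RGB
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    return (img - mean) / std


bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255  # pure blue in BGR order
out = normalize(bgr, mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)
print(out[0, 0])
```

Note that mean and std are given in the channel order of the output, so with `to_rgb=True` they are RGB statistics.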

4. Pad

    def _pad_img(self, results):
        """Pad images according to ``self.size``."""
        pad_val = self.pad_val.get('img', 0)
        for key in results.get('img_fields', ['img']):	# pad every image field, by default just img
            if self.pad_to_square:
                max_size = max(results[key].shape[:2])
                self.size = (max_size, max_size)
            if self.size is not None:
                padded_img = mmcv.impad(
                    results[key], shape=self.size, pad_val=pad_val)
            elif self.size_divisor is not None:
                padded_img = mmcv.impad_to_multiple(
                    results[key], self.size_divisor, pad_val=pad_val)
            results[key] = padded_img
        results['pad_shape'] = padded_img.shape
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor
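The `size_divisor` branch rounds height and width up to the next multiple of the divisor and fills the new area with `pad_val`. A minimal NumPy sketch, assuming the padding is added on the bottom and right edges; the function name is hypothetical:

```python
import numpy as np


def pad_to_multiple(img, divisor, pad_val=0):
    """Pad img on the bottom/right so height and width are multiples of divisor."""
    h, w = img.shape[:2]
    pad_h = int(np.ceil(h / divisor)) * divisor
    pad_w = int(np.ceil(w / divisor)) * divisor
    padded = np.full((pad_h, pad_w) + img.shape[2:], pad_val, dtype=img.dtype)
    padded[:h, :w] = img  # original pixels stay in the top-left corner
    return padded


img = np.ones((800, 1199, 3), dtype=np.uint8)
padded = pad_to_multiple(img, divisor=32)
print(padded.shape)
```

Because the image stays anchored at the top-left corner, box coordinates do not need to change after padding.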

5. DefaultFormatBundle

if 'img' in results:
    img = results['img']
    results = self._add_default_meta_keys(results)
    if len(img.shape) < 3:
        img = np.expand_dims(img, -1)    # grayscale image: add a channel axis
    img = np.ascontiguousarray(img.transpose(2, 0, 1))    # HWC -> CHW
    results['img'] = DC(to_tensor(img), stack=True)
for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
    if key not in results:
        continue
    results[key] = DC(to_tensor(results[key]))	# convert to a tensor and wrap it in a DataContainer
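The layout change applied to the image here can be tried in isolation with NumPy. The sketch below only covers the array reshaping (HWC to CHW, with a channel axis added for grayscale input); it leaves out the tensor conversion and the DataContainer wrapping, and the function name is hypothetical.

```python
import numpy as np


def to_chw(img):
    """Convert an (H, W) or (H, W, C) image to a contiguous (C, H, W) array."""
    if img.ndim < 3:
        img = np.expand_dims(img, -1)  # (H, W) -> (H, W, 1)
    return np.ascontiguousarray(img.transpose(2, 0, 1))


color = to_chw(np.zeros((4, 6, 3), dtype=np.float32))
gray = to_chw(np.zeros((4, 6), dtype=np.float32))
print(color.shape, gray.shape)
```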

6. Collect

data = {}
img_meta = {}
for key in self.meta_keys:
    img_meta[key] = results[key]	# copy the meta fields the model needs from the results dict
data['img_metas'] = DC(img_meta, cpu_only=True)
for key in self.keys:
    data[key] = results[key]
return data		# return only the data that is needed
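Stripped of the DataContainer wrapper, the collection step is a plain dict selection. A minimal sketch with ordinary dicts; the function name and the sample keys and values are made up for illustration:

```python
def collect(results, keys, meta_keys):
    """Keep only the requested keys, grouping meta fields under 'img_metas'."""
    data = {}
    data['img_metas'] = {key: results[key] for key in meta_keys}
    for key in keys:
        data[key] = results[key]
    return data


results = {
    'img': 'image tensor placeholder',
    'gt_bboxes': [[10, 20, 30, 40]],
    'gt_labels': [3],
    'filename': 'demo.jpg',
    'img_shape': (800, 1199, 3),
    'pad_shape': (800, 1216, 3),
    'img_fields': ['img'],  # dropped: not requested in keys or meta_keys
}
data = collect(results,
               keys=['img', 'gt_bboxes', 'gt_labels'],
               meta_keys=['filename', 'img_shape', 'pad_shape'])
print(sorted(data.keys()))
```

Anything not listed in `keys` or `meta_keys`, such as the bookkeeping lists of field names, does not reach the model.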

The final data dict

[Figure 2: the final data dict]
