[拆轮子] PaddleDetection 中的预处理 Decode

相对路径在这里 ppdet/data/transform/operators.py

在上篇 https://blog.csdn.net/HaoZiHuang/article/details/128391985

def __getitem__(self, idx):
	
	......
	    
    # ------- 对当前数据项进行之前的 transform ------- 
    return self.transform(roidb)

self.transform 进行数据预处理和数据增强,来看看最基本的数据加载部分 Decode,先来看看其基类

1. 基类 BaseOperator

class BaseOperator(object):
    def __init__(self, name=None):
        if name is None:
            name = self.__class__.__name__
        self._id = name + '_' + str(uuid.uuid4())[-6:]

    def apply(self, sample, context=None):
        """ Process a sample.
        Args:
            sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
            context (dict): info about this sample processing
        Returns:
            result (dict): a processed sample
        """
        return sample

    def __call__(self, sample, context=None):
        """ Process a sample.
        Args:
            sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
            context (dict): info about this sample processing
        Returns:
            result (dict): a processed sample
        """
        if isinstance(sample, Sequence):
            for i in range(len(sample)):
                sample[i] = self.apply(sample[i], context)
        else:
            sample = self.apply(sample, context)
        return sample

    def __str__(self):
        return str(self._id)

BaseOperator__call__ 会调用 apply 去进行输出处理,如果输入是列表,则对每一个元素都进行处理,如果只是一个元素,则只对其进行处理

__init__ 中进行 self._id 的命名

def __init__(self, name=None):
    if name is None:
        name = self.__class__.__name__
    self._id = name + '_' + str(uuid.uuid4())[-6:] # uuid 用来生成随机串

1. 子类 Decode

Decode 重写了 apply 方法, 参数 context 并未被使用

class Decode(BaseOperator):
    def __init__(self):
        """ Transform the image data to numpy format following the rgb format
        """
        super(Decode, self).__init__()

    def apply(self, sample, context=None):
        """ load image if 'im_file' field is not empty but 'image' is"""
        if 'image' not in sample:
            with open(sample['im_file'], 'rb') as f:
                sample['image'] = f.read()     # 先用二进制的方式读入
            sample.pop('im_file')              # 删除 im_file 图片路径 key-value 对

        try:
        	# ------------ 将二进制形式转化为 np.ndarray 格式 ------------
            im = sample['image']
            data = np.frombuffer(im, dtype='uint8')
            im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
			
			# 是否保留原图
            if 'keep_ori_im' in sample and sample['keep_ori_im']:
                sample['ori_image'] = im
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # 转化为 RGB 格式
        except:
            im = sample['image']
		
		# --------- 将图片留在 sample 中 ---------
        sample['image'] = im
		
		# --------- 检验 h 与 w 的合法性或赋值 ---------
        if 'h' not in sample:
            sample['h'] = im.shape[0]
        elif sample['h'] != im.shape[0]:
            logger.warning(
                "The actual image height: {} is not equal to the "
                "height: {} in annotation, and update sample['h'] by actual "
                "image height.".format(im.shape[0], sample['h']))
            sample['h'] = im.shape[0]
        if 'w' not in sample:
            sample['w'] = im.shape[1]
        elif sample['w'] != im.shape[1]:
            logger.warning(
                "The actual image width: {} is not equal to the "
                "width: {} in annotation, and update sample['w'] by actual "
                "image width.".format(im.shape[1], sample['w']))
            sample['w'] = im.shape[1]
		
		# 写入图片 shape 与 缩放比例
        sample['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
        sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
        return sample

最终的 sample 变量是个字典:

>>> sample.keys()
dict_keys(['im_id', 'h', 'w', 'is_crowd', 'gt_class', 'gt_bbox', 'curr_iter', 'image', 'im_shape', 'scale_factor'])
{'curr_iter': 0,
 'gt_bbox': array([[  3.27, 266.85, 404.5 , 475.1 ],
       ......
       [461.77, 253.68, 470.01, 286.99]], dtype=float32),
 'gt_class': array([[59],
       ......
       [73]], dtype=int32),
 'im_id': array([632]),
 'im_shape': array([483., 640.], dtype=float32),
 'image': array([[[100, 101,  87],
        ......
        [ 55,  62,  54]]], dtype=uint8),
 'is_crowd': array([[0],
       ......
       [0]], dtype=int32),
 'scale_factor': array([1., 1.], dtype=float32),
 'h': 483.0,
 'w': 640.0}

顺便说一下,PaddleDetection自定义使用的文档在这里:
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/advanced_tutorials/READER.md

关于类似 torch 一样层初始化的代码,可以参考PaddleDetecion的这个部分:
ppdet/modeling/initializer.py

你可能感兴趣的:(PaddleDetection,python,numpy)