[PaddleSeg 源码阅读] PaddleSeg 自定义数据类

先感受一下怎么改:

Config 文件中的item与类的 __init__.py 参数一一对应

在 PaddleSeg 的 Config 文件中:

val_dataset:
  type: Dataset
  dataset_root: /root/share/program/save/data/portraint
  val_path: /root/share/program/save/data/portraint/txtfiles/test.txt
  num_classes: 2
  transforms:
    - type: Resize
      target_size: [224, 224]
    - type: Normalize
  mode: val

上边的 Config 参数就是传入 Dataset 类的 __init__ 中的参数

    def __init__(self,
                 transforms,
                 dataset_root,
                 num_classes,
                 mode='train',
                 train_path=None,
                 val_path=None,
                 test_path=None,
                 separator=' ',
                 ignore_index=255,
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        self.file_list = list()
        self.mode = mode.lower()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.edge = edge
        
		......

如果这个 Dataset 不符合你的需求,则可以自定义一类,如 XXDataset

则将 Config 文件中的 Dataset type 部分改成 你自定义的类 XXDataset

该类需要在 paddleseg/datasets/__init__.py 中导入,可以直接在 __init__.py 中定义,也可以新建一个文件,然后在 __init__.py 中导入即可.

注意,在自定义数据类的开始,要放上这个装饰器,(要这样注册一下?)

@manager.DATASETS.add_component

(啊这,我我记得之前不用啊,或者说之前 PaddleClas 不用?

举个例子:

paddleseg/datasets/dataset.py 中:
[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第1张图片

config 文件:
[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第2张图片

__init__.py 中导入:
[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第3张图片

然后这里要注意一下:
[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第4张图片
如果不定义这两个属性,这里会报错,transforms 属性,按理说应该是参数传进来的,但是我这里是直接预处理的图片,所以给了 None 一般来说要传入的

所以我这里还得改一个地方,反正哪里报错,就改哪里就行

[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第5张图片

这里再补充一下,一般情况下,用传进来的 Transform 参数做预处理,不要像我这样:

class myDataset(Dataset):
    def __init__(self, data_root, 
                        img_name="images", 
                        ann_name="annotations"):
        
        self.img_path = os.path.join(data_root, img_name)
        self.ann_path = os.path.join(data_root, ann_name)
        
        # 务必传入相对路径,绝对路径 OpenCV 读不了
        self.imgs = [os.path.join(self.img_path, img) for img in os.listdir(self.img_path)]
        
        self.anns = [img.replace(".jpg", ".png") for img in self.imgs]
        self.anns = [img.replace(img_name, ann_name) for img in self.anns]
        

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        img = cv2.imread(self.imgs[idx])
        ann = cv2.imread(self.anns[idx], 0)
        
        img = self.img_preprocess(img)
        ann = self.ann_preprocess(ann)
        
        return img, ann
    
    @staticmethod
    def img_preprocess(img):
        img = cv2.resize(img, (224, 224)).astype(np.float32)
        img = img.transpose(2, 0, 1)
        img = img[None]
        img /= 255
        return img
    
    @staticmethod
    def ann_preprocess(ann):
        ann = cv2.resize(ann, (224, 224))
        return ann

我这里专门定义了 self.img_preprocessself.ann_preprocess 去读取图片,并做预处理.

所以我上边的代码,会把 self.transformNone,然后传入的参数给 []

这样也行,反正也懒得在 transforms 中注册一个新的 transform

但是,PaddleSeg 在 transforms 中做了一点儿偷懒的工作,会帮咱自动读入图片和Label

[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第6张图片
来看一下,到底是 transforms 中哪个类会读取图片:

在 Compose 的 __call__ 函数中:
[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第7张图片
如果传入的参数是 str 则其会读入对应的图

(注意,这里读进来的图片是 0-255 的,是uint8的)

(注意,for循环下的操作,是将所有的Transform做完,然后才进行的 Transpose)
也就是说,在所有的 Transform 之后,其输出的是 CHW, 也就是通道数在最前面
另外,最后一步的转置是 np 的操作,也就是说 PaddleSeg 的 Transform 大部分是基于 numpy 实现的(只有读取图片时,会用到 PIL.Image 类)


最后看看到底有哪些装饰器:
[PaddleSeg 源码阅读] PaddleSeg 自定义数据类_第8张图片

manager.MODELS
manager.BACKBONES
manager.DATASETS
manager.TRANSFORMS
manager.LOSSES

大概,后边加上 .add_component 后即可添加新对象

你可能感兴趣的:(每日一氵,paddlepaddle历险记,python,机器学习,开发语言)