先感受一下怎么改:
Config 文件中的item与类的 __init__.py
参数一一对应
在 PaddleSeg 的 Config 文件中:
val_dataset:
type: Dataset
dataset_root: /root/share/program/save/data/portraint
val_path: /root/share/program/save/data/portraint/txtfiles/test.txt
num_classes: 2
transforms:
- type: Resize
target_size: [224, 224]
- type: Normalize
mode: val
上边的 Config 参数就是传入 Dataset 类的 __init__
中的参数
def __init__(self,
transforms,
dataset_root,
num_classes,
mode='train',
train_path=None,
val_path=None,
test_path=None,
separator=' ',
ignore_index=255,
edge=False):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode.lower()
self.num_classes = num_classes
self.ignore_index = ignore_index
self.edge = edge
......
如果这个 Dataset 不符合你的需求,则可以自定义一类,如 XXDataset
则将 Config 文件中的 Dataset type 部分改成 你自定义的类 XXDataset
该类需要在 paddleseg/datasets/__init__.py
中导入,可以直接在 __init__.py
中定义,也可以新建一个文件,然后在 __init__.py
中导入即可.
注意,在自定义数据类的开始,要放上这个装饰器,(要这样注册一下?)
@manager.DATASETS.add_component
(啊这,我我记得之前不用啊,或者说之前 PaddleClas 不用?
举个例子:
paddleseg/datasets/dataset.py
中:
然后这里要注意一下:
如果不定义这两个属性,这里会报错,transforms 属性,按理说应该是参数传进来的,但是我这里是直接预处理的图片,所以给了 None
一般来说要传入的
所以我这里还得改一个地方,反正哪里报错,就改哪里就行
这里再补充一下,一般情况下,用传进来的 Transform 参数做预处理,不要像我这样:
class myDataset(Dataset):
def __init__(self, data_root,
img_name="images",
ann_name="annotations"):
self.img_path = os.path.join(data_root, img_name)
self.ann_path = os.path.join(data_root, ann_name)
# 务必传入相对路径,绝对路径 OpenCV 读不了
self.imgs = [os.path.join(self.img_path, img) for img in os.listdir(self.img_path)]
self.anns = [img.replace(".jpg", ".png") for img in self.imgs]
self.anns = [img.replace(img_name, ann_name) for img in self.anns]
def __len__(self):
return len(self.anns)
def __getitem__(self, idx):
img = cv2.imread(self.imgs[idx])
ann = cv2.imread(self.anns[idx], 0)
img = self.img_preprocess(img)
ann = self.ann_preprocess(ann)
return img, ann
@staticmethod
def img_preprocess(img):
img = cv2.resize(img, (224, 224)).astype(np.float32)
img = img.transpose(2, 0, 1)
img = img[None]
img /= 255
return img
@staticmethod
def ann_preprocess(ann):
ann = cv2.resize(ann, (224, 224))
return ann
我这里专门定义了 self.img_preprocess
和 self.ann_preprocess
去读取图片,并做预处理.
所以我上边的代码,会把 self.transform
给 None
,然后传入的参数给 []
这样也行,反正也懒得在 transforms 中注册一个新的 transform
但是,PaddleSeg 在 transforms 中做了一点儿偷懒的工作,会帮咱自动读入图片和Label
来看一下,到底是 transforms 中哪个类会读取图片:
在 Compose 的 __call__
函数中:
如果传入的参数是 str 则其会读入对应的图
(注意,这里读进来的图片是 0-255 的,是uint8的)
(注意,for循环下的操作,是将所有的Transform做完,然后才进行的 Transpose)
也就是说,在所有的 Transform 之后,其输出的是 CHW, 也就是通道数在最前面
另外,最后一步的转置是 np 的操作,也就是说 PaddleSeg 的 Transform 大部分是基于 numpy 实现的(只有读取图片时,会用到 PIL.Image
类)
manager.MODELS
manager.BACKBONES
manager.DATASETS
manager.TRANSFORMS
manager.LOSSES
大概,后边加上 .add_component
后即可添加新对象