Paddle:加载自定义数据集

Paddle提供两种方式来加载数据集:

1:加载内置数据

2:加载自定义数据

1:加载内置数据

飞桨框架在 paddle.vision.datasets 和 paddle.text 目录下内置了一些经典数据集可直接调用,通过以下代码可查看飞桨框架中的内置数据集。

import paddle
print('计算机视觉(CV)相关数据集:', paddle.vision.datasets.__all__)
print('自然语言处理(NLP)相关数据集:', paddle.text.__all__)

具体使用可参考官方文档:数据集定义与加载-使用文档-PaddlePaddle深度学习平台

2:加载自定义数据 

在实际的场景中,一般需要使用自有的数据来定义数据集,这时可以通过 paddle.io.Dataset 基类来实现自定义数据集。

可构建一个子类继承自 paddle.io.Dataset ,并且实现下面的三个函数:

(是不是很眼熟,不能说和Pytorch完全相同,只能说是一模一样。目的是降低迁移学习的难度)

1、__init__:完成数据集初始化操作,将磁盘中的样本文件路径和对应标签映射到一个列表中。

2、__getitem__:定义指定索引(index)时如何获取样本数据,最终返回对应 index 的单条数据(样本数据、对应的标签)。

3、__len__:返回数据集的样本总数。

直接上示例代码:

import os
import cv2
import numpy as np
from paddle.io import Dataset
from paddle.vision import transforms as T


'''
paddle-API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html
'''

class ListDataset(Dataset):
    def __init__(self, list_file, mode='train'):
        if mode == 'train':
            print("Loading train data ......")
        else:
            print("Loading test data ......")
        # mode
        self.mode = mode
        # load list
        self.data_list = []
        with open(list_file, "r") as f:
            self.data_list = f.readlines()
        # define img transform
        self.transform_train = T.Compose([
            T.Resize((128, 64), interpolation='nearest'),
            T.ContrastTransform(0.2),
            T.BrightnessTransform(0.2),
            T.RandomHorizontalFlip(0.5),
            T.RandomRotation(15),
            T.Transpose(),
            T.Normalize(mean=[127.5, 127.5, 127.5],  data_format='CHW', std=[127.5, 127.5, 127.5],  to_rgb=True)])
        self.transfrom_eval = T.Compose([
            T.Resize((128, 64), interpolation='nearest'),
            T.Transpose(),
            T.Normalize(mean=[127.5, 127.5, 127.5],  data_format='CHW', std=[127.5, 127.5, 127.5],  to_rgb=True)])

    def __getitem__(self, index):
        line_info = self.data_list[index].strip().split(' ')
        img_bgr = cv2.imread(line_info[0])
        img_label = [int(i) for i in line_info[1:]]
        if self.mode == 'train':
            img = self.transform_train(img_bgr)
        else:
            img = self.transfrom_eval(img_bgr)

        return img, img_label

    def __len__(self):
        return len(self.data_list)

对于遇到不清楚的API:直接翻官方文档。如果还不清楚,那就翻对应的pytorch文档。两个基本是相同的。

paddle-API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html

你可能感兴趣的:(Paddle,paddle,python,人工智能)