Paddle提供两种方式来加载数据集:
1:加载内置数据
2:加载自定义数据
1:加载内置数据
飞桨框架在 paddle.vision.datasets 和 paddle.text 目录下内置了一些经典数据集可直接调用,通过以下代码可查看飞桨框架中的内置数据集。
import paddle
print('计算机视觉(CV)相关数据集:', paddle.vision.datasets.__all__)
print('自然语言处理(NLP)相关数据集:', paddle.text.__all__)
具体使用可参考官方文档:数据集定义与加载-使用文档-PaddlePaddle深度学习平台
2:加载自定义数据
在实际的场景中,一般需要使用自有的数据来定义数据集,这时可以通过 paddle.io.Dataset 基类来实现自定义数据集。
可构建一个子类继承自 paddle.io.Dataset
,并且实现下面的三个函数:
(是不是很眼熟,不能说和Pytorch完全相同,只能说是一模一样。目的是降低迁移学习的难度)
1、__init__:完成数据集初始化操作,将磁盘中的样本文件路径和对应标签映射到一个列表中。
2、__getitem__:定义指定索引(index)时如何获取样本数据,最终返回对应 index 的单条数据(样本数据、对应的标签)。
3、__len__:返回数据集的样本总数。
直接上示例代码:
import os
import cv2
import numpy as np
from paddle.io import Dataset
from paddle.vision import transforms as T
'''
paddle-API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html
'''
class ListDataset(Dataset):
def __init__(self, list_file, mode='train'):
if mode == 'train':
print("Loading train data ......")
else:
print("Loading test data ......")
# mode
self.mode = mode
# load list
self.data_list = []
with open(list_file, "r") as f:
self.data_list = f.readlines()
# define img transform
self.transform_train = T.Compose([
T.Resize((128, 64), interpolation='nearest'),
T.ContrastTransform(0.2),
T.BrightnessTransform(0.2),
T.RandomHorizontalFlip(0.5),
T.RandomRotation(15),
T.Transpose(),
T.Normalize(mean=[127.5, 127.5, 127.5], data_format='CHW', std=[127.5, 127.5, 127.5], to_rgb=True)])
self.transfrom_eval = T.Compose([
T.Resize((128, 64), interpolation='nearest'),
T.Transpose(),
T.Normalize(mean=[127.5, 127.5, 127.5], data_format='CHW', std=[127.5, 127.5, 127.5], to_rgb=True)])
def __getitem__(self, index):
line_info = self.data_list[index].strip().split(' ')
img_bgr = cv2.imread(line_info[0])
img_label = [int(i) for i in line_info[1:]]
if self.mode == 'train':
img = self.transform_train(img_bgr)
else:
img = self.transfrom_eval(img_bgr)
return img, img_label
def __len__(self):
return len(self.data_list)
对于遇到不清楚的API:直接翻官方文档。如果还不清楚,那就翻对应的pytorch文档。两个基本是相同的。
paddle-API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html