MindSpore图像读取和预处理教程

本教程包含MindSpore图像数据读取和一些基本数据集操作。

MindSpore 加载图像数据集

MindSpore加载图像数据集常用API:mindspore.dataset.ImageFolderDataset

数据集需要的分布格式为每个类别一个一个文件夹,文件夹内部结构如图,每个文件夹的名称会默认作为数据集的label:
MindSpore图像读取和预处理教程_第1张图片

加载

读取代码如下,执行结束后会生成一个Dataset对象:

import mindspore.dataset as ds

dataset_dir = "C:\\datasets\\caltech_for_user\\train"
dataset = ds.ImageFolderDataset(dataset_dir, decode=True)

MindSpore图像预处理

预处理

MindSpore图像预处理统一使用mindspore.dataset.vision.c_transforms模块,其中包含变形(Resize)、标准化(Normalize)、转置(HWC2HCW)等所有图像相关的预处理操作。

import mindspore.dataset.vision.c_transforms as c_transforms

image_size = 32
mean = [0.5 * 255] * 3
std = [0.5 * 255] * 3

trans = [
    c_transforms.Resize((image_size, image_size)),
    c_transforms.Normalize(mean=mean, std=std),
    c_transforms.HWC2CHW() 
]

dataset = dataset.map(operations=trans, num_parallel_workers=1)

验证集分割

train, val = dataset.split([0.8, 0.2])

batch

如果需要使用mini-batch训练,需要使用如下代码对数据集进行处理:

batch_size = 128
train = train.batch(batch_size, drop_remainder=True)

打印图像数据集信息

MindSpore数据集使用create_dict_iterator()生成一个可迭代对象,然后使用next得到每一个样本,其中mindspore.dataset.ImageFolderDataset读取默认的图像的关键字为image,标签为label:

for i in range(5):
    data = next(train.create_dict_iterator())
    print(data['label'])
    print(data['image'].shape)

你可能感兴趣的:(MindSpore教程大全,深度学习,计算机视觉)