gluon学习系列(三)—Dataset类和DataLoader类

龟龟是最可爱的小猫咪

Gluon使用dataset和dataloader来进行数据的加载和遍历

Dataset类

mxnet.gluon.data是gluon的数据加载API

mxnet.data.vision中则是一个和视觉相关的模块

mxnet.gluon.data 下面有已下几个类:

gluon.data.Dataset(object) # 抽象的数据类
gluon.data.ArrayDataset(Dataset) # 组合的shujulei
mx.gluon.data.RecordFileDataset(Dataset)# rec数据类

所有的Dataset类中都有以下几个方法:

- __genitem__(self, idx)           #返回第idx个样本

- transform()                      # 数据变化函数,常用于数据增广

- transform_first(fn, lazy=True)    # 仅变换dataset而不对label进行数据增广

- __len__(self)                     # 返回样本数量

Dataset类的关键就是能够根据下标获取对应样本。

在使用中用于加载,变换数据

mxnet.gluon.data.vision 下面有以下几个CV中常用数据类:

- MNIST(dataset._DownloadedDataset)

- FashionMNIST(MNIST)

- CIFAR10(dataset._DownloadedDataset)

- CIFAR100(CIFAR10)

- ImageRecordDataset(dataset.RecordFileDataset) # 从rec格式加载数据集

- ImageFolderDataser(dataset.Dataset)  # 从文件夹加载

DataLoader类

DataLoader类的核心功能:

  • 用于加载Dataset,返回batch大小的样本

  • 并行加载数据

transforms模块

在gluon.data.vision下面

使用时

from mxnet.gluon.data.vision import transforms

定义了许多数据变换的子类,这些子类继承自Block类或HybridBlock类


gluon.data.vision.transforms.Compose  #可以将这些变换按照顺序连接起来

gluon.data.vision.transforms.ToTensor # 将输入图像由HWC维度转换为CHW维度


"""Converts an image NDArray or batch of image NDArray to a tensor NDArray.

    Converts an image NDArray of shape (H x W x C) in the range
    [0, 255] to a float32 tensor NDArray of shape (C x H x W) in
    the range [0, 1].

    If batch input, converts a batch image NDArray of shape (N x H x W x C) in the
    range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W).

    Inputs:
        - **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x H x W x C) shape and float32 type.
"""
gluon.data.vision.transforms.Normalize # 将数据按照均值和标准差标准化

Dataset类, DataLoader类以及transforms模块使用示例

jitter_param = 0.4
lighting_param = 0.1

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomFlipLeftRight(),
    transforms.RandomBrightness(0.1),
    transforms.RandomFlipTopBottom(),
    transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                 saturation=jitter_param),
    transforms.RandomLighting(lighting_param),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_data = gluon.data.DataLoader(
    gluon.data.vision.ImageFolderDataset(train_path).transform_first(transform_train),
    batch_size=batch_size, shuffle=True, num_workers=num_workers)

你可能感兴趣的:(mxnet,python)