常用的gluon中的关于data的一些操作。本文总结一番,用于学习和理解。如果写的不错,请点个赞哦~
————————
导入
import mxnet as mx
from mxnet import gluon
gluon.data
下的所有模块
dir(gluon.data)
['ArrayDataset', 'BatchSampler', 'DataLoader', 'Dataset', 'FilterSampler', 'RandomSampler','RecordFileDataset', 'Sampler', 'SequentialSampler', 'SimpleDataset', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'dataloader', 'dataset', 'sampler', 'vision']
__getitem__(idx)
:数据加载,用于返回第idx个样本
__len__()
:用于返回数据集的样本的数量
transform(fn, lazy = True)
:数据变换,用于返回对每个样本利用fn函数进行数据变换(增广)后的Dataset
transform_first(fn, lazy = True)
:数据变换,用于返回对每个样本的特征利用fn函数进行数据变换(增广)后的Dataset,而不对label进行数据增广
sample(self, sampler)
返回带有采样器采样的元素的新数据集。
take(self, count)
返回一个新的数据集,其中最多包含“ count”个样本。
filter(self, fn)
返回一个新的数据集,其样本由过滤器函数“ fn”过滤。
shard(self, num_shards, index)
返回一个新数据集,仅包含该数据集的1 / num_shards。
参数:
count:int或None
一个整数,表示应构成新数据集的该数据集的元素数。 如果count为None,或者count大于此数据集的大小,则新数据集将包含该数据集的所有元素。
MNIST就是著名的手写数字识别库,其中包含0至9等10个数字的手写体,图片大小为28*28的灰度图,目标是根据图片识别正确的数字。
MNIST库在MXNet中被封装为MNIST类,数据存储于.mxnet/datasets/mnist中。如果下载MNIST数据较慢,可以选择到MNIST官网下载,放入mnist文件夹中即可。在MNIST类中:
- 参数train:是否为训练数据,其中true是训练数据,false是测试数据;
- 参数transform:数据的转换函数,lambda表达式,转换数据和标签为指定的数据类型;
class mxnet.gluon.data.vision.datasets.MNIST(root ='〜/ .mxnet / datasets / mnist',train = True,transform = None )
参数:
- root(str ,默认’〜/ .mxnet / datasets / mnist’) - 用于存储数据的临时文件夹路径。
- train(bool ,默认为True) - 是否加载训练或测试集。
- transform(函数,默认无) - 用户定义的回调函数,用于转换每个样本。
mxnet.gluon.data.dataset
“(x1 [i],x2 [i],...)”
。ArrayDataset(*args)
将几个数据集合并起来构建模块中用的到的数据集。
import mxnet as mx
## 定义
mx.random.seed(42) # 固定随机数种子,以便能够复现
X = mx.random.uniform(shape = (10, 3))
y = mx.random.uniform(shape = (10, 1))
dataset = mx.gluon.data.ArrayDataset(X, y) # ArrayDataset不需要从硬盘上加载数据
## 使用
dataset[5] # 将返回第6个样本的特征和标签,(特征,标签)是一个元组的方式
out:
(
[0.24848557 0.2555806 0.11790147]
,
[0.8438072]
)
解释:X,y
分别是shape为(10,3)和(10,1)的2维ndarray。
print("X[5]:\n",X[5],'\n y[5]:\n',y[5])
out:
X[5]:
[0.24848557 0.2555806 0.11790147]
y[5]:
[0.8438072]
__getitem__(self, idx)
# 返回数据集中索引为5的元素《==》dataset[5]
dataset.__getitem__(5)
__init__(self, *args)
__len__(self)
# 返回数据集中样本的个数,[每个样本是一个元组]
dataset.__len__()
Dataset
继承的方法略
DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, num_workers=0, pin_memory=False, pin_device_id=0, prefetch=None, thread_pool=False, timeout=120)
参数理解
dataset
(数据集) - 源数据集。请注意,numpy和mxnet数组可以直接用作数据集。batch_size
(int) - 最小批量的大小。shuffle
(bool) - 是否洗牌。sample
(采样器) - 要使用的采样器。指定采样器或混洗,而不是两者。last_batch
({‘keep’ ,‘discard’ ,‘rollover’}) -
如果batch_size不能均匀分配len(数据集),如何处理最后一批 。
keep
- 返回比前一批次样品少的批次。discard
- 如果最后一批不完整,则丢弃最后一批。rollover
- 剩余的样本将转入下一个时间段。batch_sampler
(采样器) - 返回小批量的采样器。如果指定了batch_sampler,则不要指定batch_size,shuffle,sampler和last_batch。
*batchify_fn
(可调用) -
回调函数允许用户指定如何将样本合并到批处理中。默认为 default_batchify_fnum_workers
(int ,默认值为0) - 用于数据预处理的多处理工作器的数量。 Windows尚未支持num_workers> 0。
接上面代码的数据集
data_loader = mx.gluon.data.DataLoader(dataset, batch_size = 5) # 返回一个迭代器
for X, y in data_loader:
print(X,y)
print(X.shape, y.shape)
out:
[[0.82909364 0.7478349 0.55818135]
[0.5657354 0.63112146 0.9964986 ]
[0.7411964 0.54220843 0.09796188]
[0.37050414 0.06104451 0.2780535 ]
[0.74707687 0.37641123 0.46362457]]
[[0.22065327]
[0.98274744]
[0.00898287]
[0.6343119 ]
[0.35440788]]
(5, 3) (5, 1)
[[0.24848557 0.2555806 0.11790147]
[0.5224151 0.83817047 0.73799384]
[0.02669165 0.41836238 0.24940562]
[0.30375558 0.54735136 0.35234648]
[0.65174425 0.7668355 0.18985268]]
[[0.8438072 ]
[0.05371499]
[0.0330309 ]
[0.4326624 ]
[0.30392623]]
(5, 3) (5, 1)
分析:
DataLoader
的数据源为dataset
,在上面已经知道长度为10,在生成批量数据过程中,batche_size=5,所以,会生成两个batch批次。所以, 循环两次。每次大小为5.
__del__(self)
|
__init__()
__init__(self, dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, num_workers=0, pin_memory=False, pin_device_id=0, prefetch=None, thread_pool=False, timeout=120)
Initialize self. See help(type(self)) for accurate signature.
__iter__(self)
__len__(self)
.rec
文件的数据集类RecordFileDataset(filename)
参数:
- filename : str
rec file 的路径
__getitem__(self, idx)
__init__(self, filename)
Initialize self. See help(type(self)) for accurate signature.
__len__(self)
Dataset
的函数dir(gluon.data.vision)
['CIFAR10',
'CIFAR100',
'FashionMNIST',
'ImageFolderDataset',`` 'ImageRecordDataset',`` 'MNIST',`` '__builtins__',
'__cached__',
'__doc__',
'__file__',
'__loader__',`` '__name__',
'__package__',
'__path__',
'__spec__',
'datasets',
'transforms']
MNIST数据集的Dataset
:FashionMNIST数据集的Dataset
CIFAR10数据集的Dataset
CIFAR100数据集的Dataset
含有图片的.rec文件的Dataset
import mxnet as mx
## 定义
dataset= gdata.vision.ImageFolderDataset("样本集的根路径", flag=1)
## 使用
dataset[5] # 将返回第6个样本的特征和标签,(特征,标签)
存储图片在文件夹结构的Dataset
import mxnet as mx
## 定义
file = '/xxx/train.rec'
#不需要指定idx文件路径,会从路径中自动拼接处idx的路径,例如此处为/xxx/train.idx
dataset= gdata.vision.ImageRecordDataset(file)
## 使用
dataset[5] # 将返回第6个样本的特征和标签,(特征,标签)
CLASSES 类们
- mxnet.gluon.block.Block(builtins.object)
CenterCrop
RandomResizedCrop- mxnet.gluon.block.HybridBlock(mxnet.gluon.block.Block)
Cast
CropResize
Normalize
RandomBrightness
RandomColorJitter
RandomContrast
RandomFlipLeftRight
RandomFlipTopBottom
RandomHue
RandomLighting
RandomSaturation
Resize
ToTensor- mxnet.gluon.nn.basic_layers.Sequential(mxnet.gluon.block.Block)
Compose
来源参考:https://blog.csdn.net/sd__dreamer/article/details/82596332
from mxnet.gluon import data as gdata
train_ds = gdata.vision.ImageFolderDataset("样本集的根路径", flag=1)
print train_ds[0] #变换之前的数据
## 数据变换定义 # 将一系列的变换组合起来
transform_train = gdata.vision.transforms.Compose([ # Compose将这些变换按照顺序连接起来
# 将图片放大成高和宽各为 40 像素的正方形。
gdata.vision.transforms.Resize(40),
# 随机对高和宽各为 40 像素的正方形图片裁剪出面积为原图片面积 0.64 到 1 倍之间的小正方
# 形,再放缩为高和宽各为 32 像素的正方形。
gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
ratio=(1.0, 1.0)),
# 随机左右翻转图片。
gdata.vision.transforms.RandomFlipLeftRight(),
# 将图片像素值按比例缩小到 0 和 1 之间,并将数据格式从“高 * 宽 * 通道”改为“通道 * 高 * 宽”。
gdata.vision.transforms.ToTensor(),
# 对图片的每个通道做标准化。
gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
])
train_ds_transformed = train_ds.transform_first(train_ds )
print ( train_ds_transformed[0] )#变换之后的数据
Cast
: 变换数据类型ToTensor
: 将图像数组由“高 * 宽 * 通道”改为 “通道 * 高 * 宽”Normalize
: 对图片(shape为通道 * 高 * 宽)每个通道上的每个像素按照均值和方差标准化RandomResizedCrop
: 首先按照一定的比例随机裁剪图像,然后再对图像变换高和宽Resize
: 将图像变换高和宽RandomFlipLeftRight
: 随机左右翻转# coding: utf-8
from mxnet.gluon import data as gdata
import multiprocessing
import os
def get_cifar10(root_dir, batch_size, num_workers = 1):
train_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'train'), flag=1)
valid_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'valid'), flag=1)
train_valid_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'train_valid'), flag=1)
test_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'test'), flag=1)
transform_train = gdata.vision.transforms.Compose([
# 将图片放大成高和宽各为 40 像素的正方形。
gdata.vision.transforms.Resize(40),
# 随机对高和宽各为 40 像素的正方形图片裁剪出面积为原图片面积 0.64 到 1 倍之间的小正方
# 形,再放缩为高和宽各为 32 像素的正方形。
gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
ratio=(1.0, 1.0)),
# 随机左右翻转图片。
gdata.vision.transforms.RandomFlipLeftRight(),
# 将图片像素值按比例缩小到 0 和 1 之间,并将数据格式从“高 * 宽 * 通道”改为
# “通道 * 高 * 宽”。
gdata.vision.transforms.ToTensor(),
# 对图片的每个通道做标准化。
gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
])
# 测试时,无需对图像做标准化以外的增强数据处理。
transform_test = gdata.vision.transforms.Compose([
gdata.vision.transforms.ToTensor(),
gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
])
train_ds = train_ds.transform_first(transform_train)
valid_ds = valid_ds.transform_first(transform_test)
train_valid_ds = train_valid_ds.transform_first(transform_train)
test_ds = test_ds.transform_first(transform_test)
train_data = gdata.DataLoader(train_ds, batch_size, shuffle=True, last_batch='keep',num_workers = num_workers)
valid_data = gdata.DataLoader(valid_ds, batch_size, shuffle=False, last_batch='keep', num_workers = num_workers)
train_valid_data = gdata.DataLoader(train_valid_ds, batch_size, shuffle=True, last_batch='keep', num_workers=num_workers)
test_data = gdata.DataLoader(test_ds, batch_size, shuffle=False, last_batch='keep', num_workers=num_workers)
return train_data, valid_data, train_valid_data, test_data
if __name__ == '__main__':
batch_size = 256
root_dir = '/home/face/common/samples/cifar-10/train_valid_test'
train_data, valid_data, train_valid_data, test_data = get_cifar10(root_dir, batch_size)
for batch in train_data:
data, label = batch
print data.shape, label