数据下载和预处理是机器学习、深度学习实际项目中耗时又重要的任务,
尤其是数据预处理,关系到数据质量和模型性能,往往要占据项目的大部分时间。
PyTorch涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如图:
主要包括两大部分:
(1)torch.utils.data相关部分
torch.utils.data工具包,它包括以下4个类函数。
1)Dataset:是一个抽象类,其他数据集需要继承这个类,并且覆写其
中的两个方法( getitem_()、len ())。
2)DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱
数据(shuffle)并提供并行加速等功能。
3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
4)*sampler:多种采样函数
(2)torchvision工具包相关部分
它包括4个类,各类的主要功能如下。
1)datasets:提供常用的数据集加载,设计上都是继承自
torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet和COCO
等。
2)models:提供深度学习中各种经典的网络结构以及训练好的模型
(如果选择pretrained=True),包括AlexNet、VGG系列、ResNet系列、
Inception系列等。
3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image
对象的操作。
4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个
网格中;另一个是save_img,它能将Tensor保存成图片。
utils.data包括Dataset和DataLoader。
# 官方类的定义和相关说明
class Dataset(object):
"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
上述代码是pytorch中Datasets的源码,注意成员方法__getitem__和__len__都是未实现的。我们要实现自定义Datasets类来完成数据的读取,则只需要完成这两个成员方法的重写。
首先__getitem__方法用来从datasets中读取一条数据,这条数据包含训练图片(已CV距离)和标签,参数index表示图片和标签在总数据集中的Index。
其次__len__ 方法返回数据集的总长度(训练集的总数)。
Dataset只支持两种类型的数据集:map-style datasets
, iterable-style datasets
.
__getitem__一次只能获取一个数据,
所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。
data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate at 0x7f108ee01620>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
)
函数说明
·dataset:加载的数据集。
·batch_size:批大小。
·shuffle:是否将数据打乱。
·sampler:样本抽样。
·num_workers:使用多进程加载的进程数,0代表不使用多进程。
·collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。
·pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些。
·drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
举例说明使用过程:
1)导入需要的模块
import torch
from torch.utils import data
import numpy as np
2)定义获取数据集的类。
该类继承基类Dataset,自定义一个数据集及对应标签
class TestDataset(data.Dataset):#继承Dataset
def __init__(self):
#一些由2维向量表示的数据集
self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
#这是数据集对应的标签
self.Label=np.asarray([0,1,0,1,2])
def __getitem__(self, index):
#把numpy转换为Tensor
txt=torch.from_numpy(self.Data[index])
label=torch.tensor(self.Label[index])
return txt,label
def __len__(self):
return len(self.Data)
3)获取数据集中数据
Test=TestDataset()
print(Test[2]) #相当于调用__getitem__(2)
print(Test.__len__())
#输出:
#(tensor([2, 1]), tensor(0))
#5
4) 批处理
以上数据以tuple返回,每次只返回一个样本。实际上,Dateset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和并行加速等操作,可选择DataLoader。
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
print('i:',i)
Data,Label=traindata
print('data:',Data)
print('Label:',Label)
#输出
i: 0
data: tensor([[1, 2],[3, 4]])
Label: tensor([0, 1])
i: 1
data: tensor([[2, 1],[3, 4]])
Label: tensor([0, 1])
i: 2
data: tensor([[4, 5]])
Label: tensor([2])
从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,比如对它进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器
dataiter=iter(test_loader)
imgs,labels=next(dataiter)
类型1:map-style datasets
A:构建dateset类
重点是把 x 和 label 都分别装入两个列表 self.src 和 self.trg ,然后通过 getitem(self, index)返回对应元素。
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
class My_dataset(Dataset):
def __init__(self):
super().__init__()
# 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
# 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
self.x = torch.randn(1000,3)
self.y = self.x.sum(axis=1)
self.src, self.trg = [], []
for i in range(1000):
self.src.append(self.x[i])
self.trg.append(self.y[i])
def __getitem__(self, index):
return self.src[index], self.trg[index]
def __len__(self):
# 或者return len(self.trg), src和trg长度一样
return len(self.src)
data_train = My_dataset()
data_test = My_dataset()
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)
# i_batch的多少根据batch size和def __len__(self)返回的长度确定
# batch_data返回的值根据def __getitem__(self, index)来确定
for i_batch, batch_data in enumerate(data_loader_train):
print(i_batch) # 打印batch编号
print(batch_data[0]) # 打印该batch里面src
print(batch_data[1]) # 打印该batch里面trg
# 对测试集:(下面的语句也可以)
for i_batch, (src, trg) in enumerate(data_loader_test):
print(i_batch) # 打印batch编号
print(src) # 打印该batch里面src的尺寸
print(trg) # 打印该batch里面trg的尺寸
输出
0
tensor([[ 0.2588, -0.0292, 1.0143],
[ 0.1215, -0.0259, -1.1979],
[ 0.2648, 1.7875, 0.3942],
[-0.7355, -0.9454, -0.1084],
[-0.1744, 0.1619, 0.5177]])
tensor([ 1.2439, -1.1023, 2.4465, -1.7893, 0.5051])
1
tensor([[ 0.6797, -0.3623, -0.2554],
[-1.0481, -0.7783, 1.8088],
[ 0.6535, 0.5184, -0.0382],
[ 2.3790, 1.8096, 0.1110],
[-0.3820, -1.5508, 0.3057]])
tensor([ 0.0619, -0.0176, 1.1337, 4.2997, -1.6271])
...
B:借助TensorDataset直接将数据包装成dataset类
另一种方法是直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader。
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))
# 总共有9990个数据
data = TensorDataset(src, trg)
# 9990个数据被分成1998个batch,每batch有数据5个,所以data_loader的len为1998,可以从i_batch看出
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
print(i_batch) # 打印batch编号
print(batch_data[0].size()) # 打印该batch里面src
print(batch_data[1].size()) # 打印该batch里面trg
输出
0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
2
torch.Size([5])
torch.Size([5])
...
类型2:iterable-style datasets
可迭代样式的数据集是IterableDataset的一个实例,该实例必须重写__iter__方法,该方法用于对数据集进行迭代。这种类型的数据集特别适合随机读取数据不太可能实现的情况,并且批处理大小batchsize取决于获取的数据。比如读取数据库,远程服务器或者实时日志等数据的时候,可使用该样式,一般时序数据不使用这种样式。
注意:
一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,
因为不同的目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就很不方便。不过,使用PyTorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数
https://blog.csdn.net/zuiyishihefang/article/details/105985760
https://blog.csdn.net/weixin_42468475/article/details/108714940
https://blog.csdn.net/u011995719/article/details/85102770