使用Dataset创建数据集¶
Dataset创建数据集常用的方法有:
1、使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
2、使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
3、继承 torch.utils.data.Dataset 创建自定义数据集。
此外,还可以通过
torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。
Dataset:从numpy和DataFrame创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
PyTorch 的 TensorDataset 是一个数据集包装张量。 通过定义索引的长度和方式,这也为我们提供了沿张量的一维进行迭代,索引和切片的方法。 这将使我们在训练的同一行中更容易访问自变量和因变量。
import numpy as np
import torch
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split
# 根据Tensor创建数据集
from sklearn import datasets
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
# 分割成训练集和预测集
n_train = int(len(ds_iris)*0.8)
n_valid = len(ds_iris) - n_train
ds_train,ds_valid = random_split(ds_iris,[n_train,n_valid])
print(type(ds_iris))
print(type(ds_train))
输出:
<class 'torch.utils.data.dataset.TensorDataset'>
<class 'torch.utils.data.dataset.Subset'>
#使用DataLoader加载数据集
dl_train,dl_valid = DataLoader(ds_train,batch_size= 8),DataLoader(ds_valid,batch_size=8)
for feature,label in dl_train:
print(feature,label)
break
输出:
tensor([[7.7000, 2.6000, 6.9000, 2.3000],
[6.1000, 2.8000, 4.0000, 1.3000],
[6.3000, 2.8000, 5.1000, 1.5000],
[5.5000, 2.5000, 4.0000, 1.3000],
[5.6000, 2.9000, 3.6000, 1.3000],
[6.7000, 3.3000, 5.7000, 2.1000],
[6.1000, 2.6000, 5.6000, 1.4000],
[6.7000, 3.1000, 5.6000, 2.4000]], dtype=torch.float64) tensor([2, 1, 2, 1, 1, 2, 2, 2], dtype=torch.int32)
# 使用 + 合并数据集
ds_data = ds_train + ds_valid
print('len(ds_data)',len(ds_data))
print('len(ds_train)',len(ds_train))
print('len(ds_valid)',len(ds_valid))
print(type(ds_data))
输出:
len(ds_data) 150
len(ds_train) 120
len(ds_valid) 30
<class 'torch.utils.data.dataset.ConcatDataset'>
DataLoadr:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False)
1、dataset:(数据类型 dataset)
输入的数据集,必须为Dataset类型。
2、batch_size:(数据类型 int)
每次输入数据的行数,默认为1。
3、shuffle:(数据类型 bool)
默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
4、collate_fn:(数据类型 callable,没见过的类型)
将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)
5、batch_sampler:(数据类型 Sampler)
批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。
6、sampler:(数据类型 Sampler)
采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则shuffle设置必须为False。
7、num_workers:(数据类型 Int)
工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。
8、pin_memory:(数据类型 bool)
内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
9、drop_last:(数据类型 bool)
丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
10、timeout:(数据类型 numeric)
超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
11、worker_init_fn(数据类型 callable,没见过的类型)
子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
#定义图像增强操作
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),#随机水平翻转
transforms.RandomVerticalFlip(),#随机垂直翻转
transforms.RandomRotation(45),#随机在45度内翻转
transforms.ToTensor() #转变成张量
])
transform_valid = transforms.Compose([
transforms.ToTensor()
])
# 根据图片目录创建数据集
ds_train = datasets.ImageFolder("./data/cifar2/train/",
transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./data/cifar2/test/",
transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
print(ds_train.class_to_idx)
输出:
{'0_airplane': 0, '1_automobile': 1}
dl_train = DataLoader(ds_train,batch_size= 50,shuffle=True)
dl_valid = DataLoader(ds_valid,batch_size= 50, shuffle= True)
for features,labels in dl_train:
print(features.shape)
print(labels.shape)
break
输出:
torch.Size([50, 3, 32, 32])
torch.Size([50, 1])
ImageFolder:
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
root:在root指定的路径下寻找图片
transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
target_transform:对label的转换
loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象