There are two ways to build an image data pipeline in PyTorch.
The first method:
import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,datasets
transform_train = transforms.Compose(
    [transforms.ToTensor()])
transform_valid = transforms.Compose(
    [transforms.ToTensor()])
#%%
ds_train = datasets.ImageFolder("/media/yfh/hd/resources/databases/cifar2/train/",
    transform=transform_train, target_transform=lambda t: torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("/media/yfh/hd/resources/databases/cifar2/test/",
    transform=transform_valid, target_transform=lambda t: torch.tensor([t]).float())
print(ds_train.class_to_idx) #{'0_airplane': 0, '1_automobile': 1}
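Note that the target_transform lambda wraps each integer class index in a 1-element float tensor, which is why the label batch printed later has shape [50, 1]. A quick check (a minimal sketch; the exact values depend on the first sample):
img, label = ds_train[0]
print(img.shape, label)  # e.g. torch.Size([3, 32, 32]) tensor([0.])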
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
# View a few samples
from matplotlib import pyplot as plt
plt.figure(figsize=(8,8))
for i in range(9):
    img, label = ds_train[i]
    img = img.permute(1, 2, 0)
    ax = plt.subplot(3, 3, i + 1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d" % label.item())
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
# PyTorch's default image tensor layout is Batch, Channel, Height, Width
dl_train = DataLoader(ds_train, batch_size=50, shuffle=True)  # batch size matches the shapes printed below
for x, y in dl_train:
    print(x.shape, y.shape)  # torch.Size([50, 3, 32, 32]) torch.Size([50, 1])
    break
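Here transform_train only converts images to tensors; a real training transform often adds augmentation. A minimal sketch using standard torchvision transforms (these particular augmentations are illustrative assumptions, not part of the original pipeline):
# Illustrative only: common augmentations for 32x32 CIFAR-style images
transform_train_aug = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # pad by 4 pixels, then take a random 32x32 crop
    transforms.RandomHorizontalFlip(),     # flip left-right with probability 0.5
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
])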
The second method:
Reference blog: https://blog.csdn.net/l8947943/article/details/103733473
1. Subclass torch.utils.data.Dataset and override the __getitem__() and __len__() methods
import torch
import numpy as np
# Define GetLoader, subclassing Dataset and overriding __getitem__() and __len__()
class GetLoader(torch.utils.data.Dataset):
    # The constructor receives and stores the data and labels
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index is the sample index DataLoader uses when splitting the data into batches;
    # return the data item together with its corresponding label
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # Return the dataset length so DataLoader knows how to split it into batches;
    # without it, DataLoader cannot tell how many samples there are
    def __len__(self):
        return len(self.data)
# Generate random data: 10 rows by 20 columns
source_data = np.random.rand(10, 20)
# Generate random labels: 10 rows by 1 column
source_label = np.random.randint(0, 2, (10, 1))
# Wrap the data in GetLoader, producing a Dataset object holding the data and labels
torch_data = GetLoader(source_data, source_label)
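Before handing it to DataLoader, a quick sanity check of the custom Dataset (a minimal sketch; values vary because the data is random):
print(len(torch_data))       # 10, via __len__()
data, label = torch_data[0]  # indexing calls __getitem__(0)
print(data.shape, label)     # (20,) plus the matching 1-element label array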
2. Load the data in parallel with DataLoader
from torch.utils.data import DataLoader
# Read the data in batches
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
3. Inspect the data by iterating over the DataLoader
for i, data in enumerate(datas):
    # i is the batch index; data holds that batch's samples and their labels
    print("Batch {}:\n{}".format(i, data))
The output looks like this:
Batch 0:
[tensor([[0.9599, 0.8651, 0.2255, 0.0347, 0.0917, 0.9121, 0.1441, 0.9894, 0.9187,
0.8801, 0.0485, 0.8577, 0.8432, 0.0217, 0.2609, 0.7885, 0.4271, 0.6010,
0.4486, 0.4694],
[0.0324, 0.2408, 0.4294, 0.6394, 0.9968, 0.4153, 0.5748, 0.9075, 0.8704,
0.2500, 0.5978, 0.0943, 0.9280, 0.8045, 0.5619, 0.4407, 0.0798, 0.0098,
0.3712, 0.4186],
[0.5342, 0.7337, 0.1067, 0.2624, 0.1423, 0.3960, 0.0439, 0.3460, 0.0646,
0.8649, 0.3192, 0.4209, 0.8045, 0.5303, 0.5436, 0.8913, 0.5350, 0.4947,
0.3241, 0.1768],
[0.8492, 0.0950, 0.2038, 0.0865, 0.3746, 0.4050, 0.5040, 0.5224, 0.5192,
0.7546, 0.3538, 0.1554, 0.9970, 0.2397, 0.6701, 0.1990, 0.6772, 0.5123,
0.9840, 0.5672],
[0.7546, 0.3447, 0.0682, 0.8481, 0.7333, 0.3628, 0.6533, 0.1724, 0.6848,
0.5730, 0.6727, 0.4741, 0.9487, 0.4466, 0.8268, 0.5067, 0.5117, 0.5438,
0.1003, 0.5986],
[0.3786, 0.8163, 0.3150, 0.5195, 0.9077, 0.1611, 0.8182, 0.2060, 0.3715,
0.5046, 0.5230, 0.8975, 0.7656, 0.9408, 0.8220, 0.8867, 0.0290, 0.8946,
0.7680, 0.2677]], dtype=torch.float64), tensor([[1],
[0],
[0],
[1],
[0],
[1]])]
Batch 1:
[tensor([[0.4901, 0.5575, 0.2097, 0.1098, 0.5834, 0.0306, 0.1047, 0.4017, 0.7830,
0.9238, 0.3405, 0.2155, 0.3767, 0.2743, 0.8154, 0.3525, 0.5874, 0.8691,
0.0262, 0.2904],
[0.9268, 0.8384, 0.9948, 0.2149, 0.1508, 0.2278, 0.6399, 0.3555, 0.5254,
0.6366, 0.9150, 0.0842, 0.4703, 0.3684, 0.6052, 0.1764, 0.5499, 0.7318,
0.4513, 0.3531],
[0.5359, 0.9277, 0.2643, 0.3641, 0.3117, 0.7986, 0.7952, 0.6529, 0.4539,
0.4004, 0.4223, 0.2886, 0.9924, 0.5950, 0.9733, 0.4068, 0.1523, 0.4911,
0.7287, 0.4250],
[0.0345, 0.3635, 0.9745, 0.2807, 0.1577, 0.4595, 0.6639, 0.1265, 0.7047,
0.1411, 0.4033, 0.2724, 0.4256, 0.1492, 0.8040, 0.1352, 0.4836, 0.7783,
0.7087, 0.0935]], dtype=torch.float64), tensor([[1],
[1],
[0],
[1]])]
torch.utils.data.DataLoader(dataset, batch_size, shuffle, drop_last, num_workers)
'''
dataset: the torch.utils.data.Dataset object to load from
batch_size: number of samples per batch
shuffle: whether to shuffle the data each epoch
drop_last: whether to drop the last incomplete batch when the dataset size is not divisible by batch_size
num_workers: number of subprocesses to use for data loading
'''
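For example, with the 10 samples above and batch_size=6, iteration yields batches of size 6 and 4; drop_last decides whether the incomplete final batch of 4 is kept. A minimal sketch reusing torch_data from above:
# drop_last=False keeps the final incomplete batch
loader_keep = DataLoader(torch_data, batch_size=6, shuffle=False, drop_last=False)
print([len(x) for x, y in loader_keep])  # [6, 4]
# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(torch_data, batch_size=6, shuffle=False, drop_last=True)
print([len(x) for x, y in loader_drop])  # [6]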
Note: the reference blog link is included above.