PyTorch学习笔记(3)Dataset和DataLoader

李宏毅深度学习网课作业3迟迟做不下去,发现pytorch方面要补的课还是太多,还是慢慢填坑吧。

utils.data包括Dataset和DataLoader

torch.utils.data.Dataset为抽象类
PyTorch学习笔记(3)Dataset和DataLoader_第1张图片
自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__

前者提供数据的大小(size),后者通过给定索引获取数据和标签

__getitem__一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取

首先定义获取数据集的类

该类继承基类Dataset,自定义一个数据集及对应标签。

class TestDataset(data.Dataset): # 继承Dataset
    def __init__(self):
        # 一些由2维向量表示的数据集
        self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]]) 
        # 这些是数据集对应的标签
        self.Label = np.asarray([0,1,0,1,2])
        
    def __getitem__(self, index):
        # 把numpy转换为tensor
        txt = torch.from_numpy(self.Data[index])
        label = torch.tensor(self.Label[index])
        return txt, label
    
    def __len__(self):
        return len(self.Data)
Test = TestDataset()
print(Test[2]) # 相当于调用__getitem__(2)
print(Test.__len__())

输出:

(tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
5

以上数据以tuple返回,每次只返回一个样本。实际上,Dateset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和并行加速等操作,可选择DataLoader。

DataLoader的格式为:

data.DataLoader(
	dataset,                # 加载的数据集
	batch_size=1,			# 批大小
	shuffle=False,  		# 是否将数据打乱
	sampler=None,			# 样本抽样
	batch_sampler=None,
	num_workers=0,			# 使用多进程加载的进程数,0代表不适用多进程
	collate_fn=<function *>	# 如何将多个样本数据拼成一个batch
	pin_memory=False,		# 是否将数据保存在pin memory中,pin memory中的数据转到GPU会快一些
	drop_last=False,		# dataset中的数据个数可能不是batch_size的整数倍,drop_last为true会将多出来不足一个batch的数据丢弃
	timeout=0,
	worker_init_fn=None,
)

创建一个DataLoader

Test = TestDataset()
test_loader = data.DataLoader(Test, batch_size = 2, 
				    	shuffle = False, 
				    	num_workers=2, 
				    	drop_last = True)
for i, traindata in enumerate(test_loader):
    print('i:{}'.format(i))
    Data, Label = traindata
    print('data:',Data)
    print('Label:', Label)

输出:

i:0
data: tensor([[1, 2],
        [3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i:1
data: tensor([[2, 1],
        [3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)

从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,比如对它进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器。

dataiter = iter(test_loader)
imgs,labels = next(dataiter)

般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,因为不同的目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就很不方便。
不过,使用PyTorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数。

参考资料:

  1. 《Python深度学习:基于PyTorch》
  2. https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/

你可能感兴趣的:(PyTorch,pytorch,机器学习,深度学习,python)