PyTorch学习笔记——(6)数据加载Dataset和DataLoader的使用

目录

  • 1、模型中使用数据加载器的目的
  • 2、数据集类
    • 2.1 Dataset基类介绍:
    • 2.2 数据加载案例:
  • 3、迭代数据集

1、模型中使用数据加载器的目的

在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成个个的batch,同时还会对数据进行预处理

所以,接下来介绍pytorch中的数据加载的方法。

2、数据集类

2.1 Dataset基类介绍:

在torch中提供了数据集的基类torch.utils.data.Dataset, 继承这个基类,我们能够非常快速的实现对数据的加载。

torch.utils.data.Dataset的源码如下:

from torch.utils.data import Dataset

class Dataset(object):

	def __getitem__(self, index):
		raise NotImplementedError
	def __len__(se1f):
		raise NotImp lementedError
	def __add__(se1f, other):
		return ConcatDataset([self, other])

可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

  1. __1en__方法, 能够实现通过全局的len()方法获取其中的元素个数;
  2. __getitem__ 方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第 i i i条数据。

2.2 数据加载案例:

下面通过一个例子来看看如何使用Dataset来加载数据:

数据来源: 我的数据是甘肃省的气温数据,是文本数据,你可以随便找数据,都可以练习。

将数据利用pandas读取进来,然后实现自定义的数据集类,其实就是实现上面说的__1en__方法和__getitem__ 方法,下面是代码:

from torch.utils.data import Dataset, DataLoader
import torch
import pandas as pd

data_path = r"./data/wendu_8_4_9_2.csv"

# 完成数据集类
class MyDataset(Dataset):
    def __init__(self):
        self.data = pd.read_csv(data_path).values # DataFrame类型,通过values转换成numpy类型

    def __getitem__(self, index):
        """
        必须实现,作用是:获取索引对应位置的一条数据
        :param index:
        :return:
        """
        return MyDataset.to_tensor(self.data[index])

    def __len__(self):
        """
        必须实现,作用是得到数据集的大小
        :return:
        """
        return len(self.data)

    @staticmethod
    def to_tensor(data):
        """
        将ndarray转换成tensor
        :param data: 
        :return: 
        """
        return torch.from_numpy(data)

if __name__ == "__main__":
    data = MyDataset() # 实例化对象
    print(data[0]) # 取第1条数据
    print(len(data)) # 获取长度

3、迭代数据集

使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程multiprocessing并行加载数据

在PyTorch中torch.utils.data.DataLoader提供了上述的所有方法

DataLoader使用示例:

from torch.utils.data import DataLoader
import torch
import pandas as pd


data = MyDataset() # 实例化对象,前面自定义的数据集类

# DataLoader就这一行,其实就是直接调用即可
data_loader = DataLoader(dataset=data, batch_size=2, shuffle=True, num_workers=2)


if __name__ == "__main__":
    for i in data_loader: # 可以迭代
        print(i)
        print('*'*50)

其中参数的含义:

  • 1、dataset:提前定义的dataset的实例;
  • 2、batch_size:传入数据的batch大小,常常是32、64、128、256’
  • 3、shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据;
  • 4、num_workers:加载数据的线程数。
  • 5、drop_last:bool类型,为真,表示最后的数据不足一个batch,就删掉

数据迭代器返回的结果如下:

PyTorch学习笔记——(6)数据加载Dataset和DataLoader的使用_第1张图片

这里有一点需要注意,如果我们同时获取我们自定义的数据集类MyDataset对象data的长度和 DataLoader对象data_loader的长度,我们会发现:
data_loader的长度是data的长度除以batch_size。如下面,我将batch_size设置为2,则如下:

print(len(data))
print(len(data_loader))

# 输出:
53280
26640

同时,需要注意,要是除不完,则向上取整,也就是说:如果我们的batch_zize=16,但是最后的数据只有1条,那这一条就算作一个batch,这个就是len(data_loader)的输出。

之后就可以把数据送入我们的模型了。

若有用,欢迎点赞,若有错,请指正,谢谢!!!

你可能感兴趣的:(PyTorch)