PyTorch数据加载方法

PyTorch数据加载方法

  • 数据集介绍
  • Dataset类的使用详解
  • DataLoader类的使用详解

数据集介绍

本文使用的数据集为开源的本文分类数据集SMS Spam Collection Data Set,下载地址为https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据集是从 Grumbletext 网站手动提取了 425 条 SMS 垃圾邮件的集合,由一个文本文件构成,其中每一行都是有一个类别和后面的原始消息构成。
PyTorch数据加载方法_第1张图片

Dataset类的使用详解

对与不同类型的数据集所需要的设置不同的Dataset,通常通过继承pytorch中的Dataset类进而构建模型训练需要的dataset。下面以SMS数据集为例,Dataset类的使用方法:

class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path, encoding='UTF-8').readlines()

    def __getitem__(self, index):
        # 获取索引对应位置的一条数据
        cur_line = self.lines[index].strip()  # strip取消换行符
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label, content

    def __len__(self):
        # 返回数据总数
        return len(self.lines)

可以根据自己的数据集的实际情况来修改这三个方法:

  1. __init__方法可以用来设置读取数据集等初始化数据集的基本操作
  2. __getitem__方法通常用来根据索引来返回一条对应的数据内容
  3. __len__方法通常用来返回数据总数

使用如下代码展示一下读取后的效果:

my_dataset = MyDataset()
print(my_dataset[0])
print(len(my_dataset))

其中每一条数据是以一个元组的形式保存在Dataset数据集中,元组的第一个元素为标签,第二个元素为数据内容。
输出结果

DataLoader类的使用详解

DataLoader的主要作用是将Dataset处理后的数据集进行加载整合成batch用于后续训练,使用方法如下:

from torch.utils.data import DataLoader
data_loader = DataLoader(dataset=my_dataset, batch_size=2, shuffle=True, num_workers=2)

DataLoader类主要参数如下:

  1. dataset:经过Dataset类处理过的数据集
  2. batch_size:一个batch中包含几条数据
  3. shuffle:是否打乱顺序
  4. sample:用于自定义从数据集中抽取样本的策略与方法,每次返回一个随机的索引,与shuffle互斥,如果使用了shuffle,则无法使用 sample
  5. batch_sampler :与shuffle类似,但是每次返回一批随机的索引,与batch_sizeshufflesampledrop_last互斥
  6. num_work:多线程加速读取数据
  7. pin_memory:是否将数据放在dataloader返回前将Tensors复制到设备或者CUDA夹层内存中
  8. drop_last:用于判断是否放弃最后一个不完整的batch

使用for循环来展示一下效果:

for i in data_loader:
    print(i)
    break

当batch_size为2时可以看出这个列表中有两个元组,每个元组有条个数据。第一个元组存放标签,第二个元组存放着数据的内容。
效果1
在实际项目中通常使用enumerate方法在读取每一个batch内容的同时也返回其batch的索引:

for index, (label, content) in enumerate(data_loader):
    print(index, label, content)
    break

效果如下:
效果2

你可能感兴趣的:(pytorch,python,深度学习)