Pytorch学习笔记(1):pytoch中如何加载训练数据

1.为什么不需要自己写加载方法

pytorch中提供了两个类用于训练数据的加载,分别是 torch.utils.data.Dataset和 torch.utils.data.DataLoader 。不像torchvision中集合了很多常用的计算机视觉的常用数据集,作为在音乐信息检索这方面,数据集要自己设计加载方法。如果每次不同的数据集都要自己写函数加载,

  • 每次读取代码不能够重用,不同的数据读取代码不同
  • 自己写的加载函数也会有各种问题,比如说限制数据读取速度,或者当数据集太大,直接加载到字典或者列表中会很占用内存,数据读取阶段也会占用大量时间
  • 只能单线程读取数据

这次我做的实验需要加载歌曲的梅尔频谱,每个歌曲的片段为30秒,大约是一个1290*128大小的矩阵。所以这次我决定使用pytorch的Dataset类来加载数据。

2.Dataset类

class torch.utils.data.Dataset

这个抽象类代表了数据集,任何我们自己设计的数据集类都应该是这个类的子类,继承这个类,重写 _len_() 方法,这个方法是用来获得数据集的大小,和__getitem__()方法,这个方法用来返回数据集中索引值为0到len(dataset)的元素。

  • def __getitem__(self, index): 实现这个函数,就可以通过索引值来返回训练样本数据
  • def __len__(self): 实现这个函数,返回数据集的大小
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

如果不重写这两个私有函数,就会触发错误。

3.定义自己的数据集类

于是我就针对自己的需求实现了以下的类。

class Fma_dataset(Dataset):
    # root 是训练集的根目录, mode可选的参数是train,test,validation,分别读取相应的文件夹
    def __init__(self, root, mode): 
        self.mode = mode
        self.root = root + "/fma_" + self.mode
        self.mel_cepstrum_path = self.get_sample(self.root)

    def __getitem__(self, index):
        sample = np.load(self.mel_cepstrum_path[index])
        data = torch.from_numpy(sample[0])
        target = torch.from_numpy(sample[1].astype(np.float32))
        return data, target

    def __len__(self):
        if self.mode == "train":
            return 23733  # 训练集大小
        elif self.mode == "validation":
            return 6780  # 验证集大小
        elif self.mode == "test":
            return 3390  # 测试集大小

    def get_sample(self, root):
        cepstrum = []
        for entry in os.scandir(root):
            if entry.is_file():
                cepstrum.append(entry.path)
        return cepstrum

4.DataLoader类

classtorch.utils.data.DataLoader(dataset, batch_size=1**,** shuffle=False**,** sampler=None**,** batch_sampler=None**,** num_workers=0**,** collate_fn=, pin_memory=False**,** drop_last=False**,** timeout=0**,** worker_init_fn=None**)**

仅仅有通过索引返回训练数据数不够的,我们还需要DataLoad类提供拓展功能。

  • 可以分批次读取:batch-size
  • 可以对数据进行shuffle操作
  • 可以用多个线程来读取数据

这个类我们不需要实现代码,直接调用,设置好参数就行了。

你可能感兴趣的:(深度学习,Pytorch)