Pytorch学习笔记(9) 通过DataSet、DatasetLoader构建模型输入数据集

如何将我们准备好的数据放入模型中呢? Pytorch 给出的答案都在torch.utils.data 包中。

一、先看看所有的类

这个模块中方法并不多,所以让我们先全部列出来看看,看看名字猜猜功能。

  • Class torch.utils.data.Dataset 一个抽象类, 所有其他类的数据集类都应该是它的子类。所有子类应该重载lengetitem,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
  • Class torch.utils.data.DataLoader 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
  • Class torch.utils.data.IterableDataset

  • Class torch.utils.data.TensorDataset

  • Class torch.utils.data.ConcatDataset(datasets)

  • Class torch.utils.data.ChainDataset(datasets)

  • Class torch.utils.data.Subset(dataset, indices)
    以上就是所有的类,之后的内容主要介绍Dataset和DatasetLoader这两个类,因为学会了这两个类,以后你可以按照任何你想的方式向模型中输入数据了。
    除了以上的CLASS,torch.utils.data 包中还提供了一些的数据采样的类和方法。相信大家以前都应该用过sklearn的train_test_split(),以下的其中一个方法也提供了类似的功能。

  • torch.utils.data.random_split(dataset, lengths) 按照给定的长度将数据集划分成没有重叠的新数据集组合。

  • CLASStorch.utils.data.Sampler(data_source)

  • CLASStorch.utils.data.SequentialSampler(data_source)

  • CLASStorch.utils.data.RandomSampler(...)

  • torch.utils.data.SubsetRandomSampler(...)

  • CLASStorch.utils.data.WeightedRandomSampler(...)

  • CLASStorch.utils.data.BatchSampler(sampler, batch_size, drop_last)

  • CLASStorch.utils.data.distributed.DistributedSampler(...)

二、Dataset和DatasetLoader

一般情况下,使用Dataset和DatasetLoader两个类已经可以完成大部分的数据导入。首先来看Dataset类。
在此对象中,必须重写以下两个方法。

def __getitem__(self, index)
      return  index对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
     
def  __len__():
    return  带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回

接下来看DataLoader 类

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)

关键的几个参数:

  • dataset: 就是第一个介绍的Dataset, 实例化之后传入这里
  • batch_size: 这个不多说了
  • shuffle: 对于train_data, 一般选择true; 其他一般选择false
  • sampler: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.

看看实例:
想到sklearn中提供了一些小数据集,使用鸢尾花(iris)的数据集:

def loaddata():
    iris_data = datasets.load_iris()
    return iris_data["data"], iris_data["target"]

class IrisDataset(Dataset):
    def __init__(self,irisdata,target):
        #   传入参数
        #   ndarray 类型的,可以是任何类型的
        self.irisdata = irisdata
        self.target = target
        self.lens = len(irisdata)

    def __getitem__(self, index):
        # index是方法自带的参数,获取相应的第index条数据
        return self.irisdata[index,:],self.target[index]

    def __len__(self):
        return self.lens

数据集就构架完成了,大家也可以通过DataFrame来处理数据。
然后结合DataLoader使用:

data,target = loaddata()
dataset_iris = IrisDataset(data,target)
train_loader = torch.utils.data.DataLoader(dataset_iris, batch_size=10,   shuffle=True, num_workers=4)

for i, (input, target) in enumerate(tqdm.tqdm(train_loader)):
        print(input.size())
        # 在这之后就可以进行训练了
输出

三、random_split 介绍

pytorch 中 random_split可以将实现sklearn 的 train_test_split类似的功能,大家可能注意到了,在上面的例子中只有训练数据,一般还需要有test set和valid set。
那么我们用random_split来划分数据集吧:

    data,target = loaddata()
    dataset_iris = IrisDataset(data,target)

    all_length = len(dataset_iris)
    train_size = int(0.80 * all_length)
    test_size = all_length - train_size

    train_dataset,test_dataset = torch.utils.data.random_split(dataset_iris,[train_size,test_size])

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=4)

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=4)

到这里就已经分好了,不过还是建议先通过其他方法提前分好。为了使每次结果都相同,可以设置好seed。

你可能感兴趣的:(Pytorch学习笔记(9) 通过DataSet、DatasetLoader构建模型输入数据集)