Pytorch数据集加载

在加载训练数据进入内存的时候,一般会碰到两种情况:

(1)一种就是服务器的内存足够大,可以将数据集完全读入内存;

这种情况比较简单,可以将数据集完全读进内存后,进行tokenize,然后转成Tensor,放到Dataset(或者TensorDataset)类,再构造DataLoader。DataLoader类会负责每次读取出 batch_size 个样本,并且把每个样本的同类型feature放在同一个Tensor中。

(2)另一种情况就是内存没那么大,或者数据集太大,没办法一次性完全读进内存。

这种情况可以继承Dataset类,构造自定义的Dataset子类,重载子类中的__getitem__()函数(该函数可以每次返回一个样本,并且需要将样本的各个feature转成Tensor),然后实例化自定义的Dataset子类,再构造DataLoader。

所以本文就详细的讲解一下:Dataset、DataLoader以及其各自的子类。

一、Dataset类

Dataset类可以

你可能感兴趣的:(Pytorch数据集加载)