pytorch系列4——数据Dataset&DataLoader

本文以pytorch1.10进行解读:torch — PyTorch 1.10 documentation

文本的操作在github上都有Shirley-Xie/pytorch_exercise · GitHub,且有运行结果。

1.Dataset和DataLoder介绍

1.1 Dataset

torch.utils.data.Dataset(*args**kwds)

        所有表示从键到数据样本映射的数据集都应该将其子类化。所有子类都应该覆盖__getitem__(),支持为给定的键获取数据样本。子类还可以选择性地覆盖__len__(),许多Sampler实现和DataLoader的默认选项都期望它返回数据集的大小。

        Dataset定义数据集的内容,类似于列表的数据结构,长度确定,能够用索引获取数据集中的元素。

        Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法,作用是接收一个索引, 返回一个样本。

1.2 DataLoader

    DataLoader定义了按batch加载数据集的方法,它是一个实现了`__iter__`方法的可迭代对象,每次迭代输出一个batch的数据。Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。 它更加复杂,一般较少使用。

    能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

    在绝大部分情况下,用户只需实现Dataset的`__len__`方法和`__getitem__`方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

函数签名如下:

torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)

常用dataset, batch_size, shuffle, num_workers,pin_memory, drop_last这六个参数。

  • - dataset : 数据集
  • - batch_size: 批次大小
  • - shuffle: 是否乱序
  • - sampler: 样本采样函数,一般无需设置。
  • - batch_sampler: 批次采样函数,一般无需设置。
  • - num_workers: 使用多进程读取数据,设置的进程数。
  • - collate_fn: 整理一个批次数据的函数。
  • - pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
  • - drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
  • - timeout: 加载一个数据批次的最长等待时间,一般无需设置。
  • - worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。

一般实现的代码如下:

ds = TensorDataset(torch.randn(1000,3),
                   torch.randint(low=0,high=2,size=(1000,)).float())
dl = DataLoader(ds,batch_size=4,drop_last = False)
features,labels = next(iter(dl))
print("features = ",features )
print("labels = ",labels )  

结果:

features =  tensor([[-0.3192, -1.7329, -1.7346],
        [-0.7792,  1.2145, -0.5208],
        [ 0.5105, -1.4158,  1.0757],
        [-1.3785, -1.3909, -0.7086]])
labels =  tensor([0., 0., 0., 1.])


2. Dataset和DataLoader操作步骤

获取一个batch数据的步骤

假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m 。

  1. 首先我们要确定数据集的长度n。比如:n = 1000。确定数据集的长度由Dataset__len__方法实现的。数据是元组列表,也就是特征和标签。
  2. 然后我们从0到n-1的范围中抽样出m个数(batch大小)。假定m=4, 拿到的结果是一个索引列表,类似:indices = [1,4,8,9]。从n个中抽出m个数方法由DataLoader的 sampler和 batch_sampler参数指定的。也就是shuffle和drop_last两个参数影响。
  3. 接着我们从数据集中去取这m个数对应下标的元素。拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]。根据下标取数据集中的元素 是由 Dataset的 __getitem__方法实现的。
  4. 最后我们将结果整理成两个张量作为输出。拿到的结果是两个张量,类似batch = (features,labels) , 其中 features = torch.stack([X[1],X[4],X[8],X[9]])。labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])。DataLoader的参数collate_fn指定。

        总而言之,在一个确定数据集中,按照batch的大小确定索引,然后根据索引取出对应的数据。最后整理成特征和标签在一起的样子。

具体内部方法拆解如下:

# step1: 确定数据集长度 (Dataset的 __len__ 方法实现)
ds = TensorDataset(torch.randn(1000,3),
                   torch.randint(low=0,high=2,size=(1000,)).float())
print("n = ", len(ds)) # len(ds)等价于 ds.__len__()

# step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)
sampler = RandomSampler(data_source = ds)
batch_sampler = BatchSampler(sampler = sampler, 
                             batch_size = 4, drop_last = False)
for idxs in batch_sampler:
    indices = idxs
    break 
print("indices = ",indices)

# step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)
batch = [ds[i] for i in  indices]  #  ds[i] 等价于 ds.__getitem__(i)
print("batch = ", batch)

# step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)
def collate_fn(batch):
    features = torch.stack([sample[0] for sample in batch])
    labels = torch.stack([sample[1] for sample in batch])
    return features,labels 

features,labels = collate_fn(batch)
print("features = ",features)
print("labels = ",labels)

结果:

n =  1000
indices =  [426, 137, 471, 292]
batch =  [(tensor([1.5614, 0.6875, 1.7250]), tensor(1.)), (tensor([ 0.2853, -1.4416, -0.5672]), tensor(1.)), (tensor([ 0.1800,  0.2652, -0.5301]), tensor(0.)), (tensor([-0.9303,  0.7461,  0.2575]), tensor(1.))]
features =  tensor([[ 1.5614,  0.6875,  1.7250],
        [ 0.2853, -1.4416, -0.5672],
        [ 0.1800,  0.2652, -0.5301],
        [-0.9303,  0.7461,  0.2575]])
labels =  tensor([1., 1., 0., 1.])

3.  使用Dataset创建数据集


Dataset创建数据集常用的方法有:

  • 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
  • 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
  • 继承 torch.utils.data.Dataset 创建自定义数据集。

此外,还可以通过

  •  torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
  •  调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。

此处代码是常见的自定义方法:

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info

        其中get_img_info做的是拿到数据的位置和标签,也就是元祖列表格式。 有了这个list,然后又给了data_info一个index, data_info[index] 就取出了某个(样本i_loc, label_i)。

        __getitem__()这个方法, 是不是很容易理解了, 第一行我们拿到了一个样本的图片路径和标签。然后第二行就是去找到图片,然后转成RGB数值。 第三行就是做了图片的数据预处理,最后返回了这张图片的张量形式和它的标签。

参考文章:

torch — PyTorch 1.10 documentation

GitHub - lyhue1991/eat_pytorch_in_20_days: Pytorch is delicious, just eat it! 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)_dataloader 输入两个变量-CSDN博客

你可能感兴趣的:(pytorch代码实战,pytorch,人工智能,python,深度学习)