pytorch的dataset和dataloader

简单说,dataset是数据集,dataloader是加载数据集的工具

dataset

pytorch提供了多样化的dataset方法。

  1. 如果你的数据集比较小, x x x y y y都可以load到内存里,可以直接使用pytorch的torch.utils.data.TensorDataset:
import torch
from torch.utils import data

# build a toy dataset, with a sequence of x and y using y = sin(x) + noise
T = 1000
x = torch.arange(1, T + 1, dtype=torch.float32)
y = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))

dataset = data.TensorDataset(x, y) # takes in x/y pairs as parameters
  1. 如果数据集存放于硬盘里(文件夹),或者有其他需求,比如在训练时需要从zip或者hdf5中提取数据,那么需要重载pytorch提供的dataset类(下面是pytorch的官方例子):
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

上面这个例子有3个函数:
__init__:参数可以自定义,主要目的是告诉这个类,到什么地方去找数据。上面的这个例子应该是一个图像方面的深度学习应用。了img_dir来指定图片存放的路径,annotations_file指定标注信息(y)的文件名。transform传入图片在fit之前所需进行的变换,比如augmentation(增强)、normalization(归一化)等。target_transform传入标注信息所需进行的变换,这个通常较少用到。注意,上面的这些参数并不是固定的,用户在定义Dataset类的时候,根据自己的需要进行设置。
__len__:返回数据集的长度(数据个数)
__getitem__:传入idx,返回对应的 x/y 对

这三个方法基本上可以满足不同数据集的需求。

dataloader

dataloader定义了如何加载dataset。函数定义原型如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

dataloader的参数虽然比较多,但理解和使用比较简单,可以直接参考官方的说明:
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader 这里重点说几个参数:
shuffle:是否将数据集顺序打乱(洗牌)。通常,如果应用在训练集training_set上,shuffle=True,否则,shuffle=False
num_workers:可以理解为处理数据所用的并行线程数量。实际可以尝试数值0~8。大于这个数字效果可能不理想
pin_memory:这个参数我一开始不太理解,查阅了不少资料。下面“参考”里面的最后一条SO上的回答写的比较详细。这个参数形象的理解,是(用图钉)“钉住内存”。简单的说,就是使用固定地址的内存,避免在CPU-CPU间拷贝来拷贝去,影响性能。此外,pin_memory还可以允许对CPU内存和GPU内存的操作变为“异步”:对CPU内存的处理和延时等不会影响GPU内存的处理,二者可以并行。

参考

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
https://deeplizard.com/learn/video/kWVgvsejXsE#:~:text=Different%20num_workers%20Values%3A%20Results%20%20%20%20run,%20%204%20%2014%20more%20rows%20
https://stackoverflow.com/questions/55563376/pytorch-how-does-pin-memory-work-in-dataloader

你可能感兴趣的:(深度学习,pytorch,深度学习,神经网络)