深度之眼Pytorch框架训练营第四期——数据读取机制中的Dataloader与Dataset

文章目录

      • 数据读取机制中的`Dataloader`与`Dataset`
        • 1、`Dataloader`
        • 2、`Dataset`
        • 3、实战:人民币二分类的数据读取
        • 4、数据读取源码分析

数据读取机制中的DataloaderDataset

前面学习到机器学习训练的五个步骤为:

  • 数据
  • 模型
  • 损失函数
  • 优化器
  • 迭代训练
    而这里的数据模块可以细分为四个子模块:
  • 数据收集:在进行实验之前,需要收集数据,数据包括原始样本和标签
  • 数据划分:有了原始数据之后,需要对数据集进行划分,把数据集划分为训练集、验证集和测试集;训练集用于训练模型,验证集用于验证模型是否过拟合,也可以理解为用验证集挑选模型的超参数,测试集用于测试模型的性能,测试模型的泛化能力;
  • 数据读取:数据读取的核心,细分为两个子模块——SamplerDataSet
  • Sample的功能是生成索引,也就是样本的序号
  • Dataset是根据索引去读取数据以及对应的标签
  • 数据预处理:把数据读取进来往往还需要对数据进行一系列的预处理,比如说数据的中心化,标准化,旋转或者翻转等等,pytorch中数据预处理是通过transforms进行处理

深度之眼Pytorch框架训练营第四期——数据读取机制中的Dataloader与Dataset_第1张图片

1、Dataloader

DataLoader(dataset, 
           batch_size=1,
           shuffle=False,
           sampler=None, 
           batch_sampler=None, 
           num_workers=0, 
           collate_fn=None, 
           pin_memory=False, 
           drop_last=False,
           timeout=0, 
           worker_init_fn=None, 
           multiprocessing_context=None)
  • 功能:构建可迭代的数据装载器
  • 参数:从上面的代码中可以看到,Dataloader的参数非常多,共有11个参数,但常用的就是下面五个:
  • datasetDataset类,决定数据从哪里读取及如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
  • 解释:重点解释一下epochiterationbatchsize
  • epoch:所有训练样本都已输入到模型中,称为一个epoch,1个epoch表示过了1遍训练集中的所有样本
  • iteration:一批样本输入到模型中,称之为一个iteration(training step),每次迭代更新1次网络结构的参数
  • batchsize:批大小,表示一次迭代所使用的样本量,决定一个epoch中有多少个iteration
  • 举例:如果定义10000次迭代为1个epoch,若每次迭代的batchsize设为256,那么1个epoch相当于过了2560000个训练样本
  • drop_last作用:
样本总数 Batchsize drop_last Epoch
87 8 true = 10 iteration
87 8 false = 11 iteration

2、Dataset

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __add__(self, other)
        return ConcatDataset([self,other])
  • 功能:用来定义数据从哪里读取,以及如何读取的问题,Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
  • 参数
  • getitem:接收一个索引,返回一个样本

3、实战:人民币二分类的数据读取

对人民币二分类的数据进行读取,从以下三个方面了解Pytorch的读取机制:

  • 读哪些数据
  • 从哪读数据
  • 怎么读数据

  • 设置了数据读取的路径
    dataset_dir = os.path.join("/tmp/pytorch学习/WeekTwo/lesson-6/", "data", "RMB_data")
    split_dir = os.path.join("/tmp/pytorch学习/WeekTwo/lesson-6/", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")
  • 进行数据预处理
  • Resize是对数据进行缩放
  • RandomCrop是对数据进行裁剪(起到数据增强的效果)
  • ToTensor是对数据进行转换,把图像转换成张量数据
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

注意:训练集中用到了RandomCrop进行裁剪,但测试集中不需要要进行数据增强操作

  • 构建DatasetDataLoader
  • Dataset:必须是用户自己构建的,在Dataset中会传入两个主要参数:
    • data_dir:数据的路径(从哪里读取数据)
    • transform:数据预处理
  • Dataloader:构建数据迭代器,有两个主要参数:
    • Dataset:前面构建好的RMBDataset
    • batch_sizeshuffle=True表示每一个epoch中样本都是乱序的

Dataset构建代码:

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

RMBDataset的具体实现

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}   # 初始化部分
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):  # 函数功能是根据index索引去返回图片img以及标签label
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):   # 函数功能是用来查看数据的长度,也就是样本的数量
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):   # 函数功能是用来获取数据的路径以及标签
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info    # 有了data_info,就可以返回上面的__getitem__()函数中的self.data_info[index],根据index索取图片和标签

注意:构建了两个Dataset,一个用于训练,一个用于验证

Dataloader构建代码:

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

4、数据读取源码分析

  • 数据读取的三个步骤对应的源码如下:
步骤 源码实现
读哪些数据 sampler.py输出的Index
从哪读数据 Dataset中的参数data_dir
怎么读数据 Dataset的getitem()实现根据索引去读取数据
  • 流程图:首先在for循环中去使用DataLoader,进入DataLoader之后是否采用多进程进入DataLoaderlter,进入DataLoaderIter之后会使用sampler去获取Index,拿到索引之后传输到DatasetFetcher,在DatasetFetcher中会调用DatasetDataset根据给定的Index,在getitem中从硬盘里面去读取实际的ImgLabel,读取了一个batch_size的数据之后,通过一个collate_fn将数据进行整理,整理成batch_Data的形式,接着就可以输入到模型中训练
    深度之眼Pytorch框架训练营第四期——数据读取机制中的Dataloader与Dataset_第2张图片
  • 总结:读哪些是由Sampler决定的,从哪读是由Dataset决定的,怎么读是由getitem决定的

你可能感兴趣的:(深度之眼Pytorch框架训练营第四期——数据读取机制中的Dataloader与Dataset)