pytorch数据读取:DataLoader与Dataset

 

 

 


 

数据模块通常分为四个子模块,分别是收集、划分、读取和处理。

pytorch数据读取:DataLoader与Dataset_第1张图片

 

首先要收集数据,数据分为样本和标签 。

数据划分:要将数据划分为训练集、验证集和测试集。

数据读取,也就是今天要学习的DataLoader。DataLoader还分为Sampler和Dataset两个子模块。Sampler的作用是生成索引,也就是样本的序号。Dataset是根据索引读取图片和标签。

数据预处理。pytorch中是通过transforms实现的。

 

一、DataLoader与Dataset

主要学习数据读取模块

 

1、DataLoader

pytorch数据读取:DataLoader与Dataset_第2张图片

 

DataLoad是构建一个可迭代的数据装载器。迭代的时候,每一个for循环,每一个iteration都是从DataLoader中获取一个batchsize大小的数据。

共有11个参数,常用的有五个。

dataset:

batchsize:

num_works:

shuffle:

drop_last:

 

要理解drop_last,我们先理解Epoch、iteration、和batchsize的关系。

pytorch数据读取:DataLoader与Dataset_第3张图片

 

Epoch:所有样本输入一次。

所有样本会分成很多批输入,分成多少批就是有多少个iteration。批大小就是batchsize。

 

如果不能整除就需要用到drop_last参数。

 

pytorch数据读取:DataLoader与Dataset_第4张图片

 

2、Dataset

Dataset是用来定义数据从哪里读取以及如何读取的问题。

pytorch中给定的Dataset是一个抽象类。我们所有的自定义的Dataset都要继承它,并一定要重写它的__getitem()方法。

__getitem__()实现的功能就是接受索引,返回样本。

 

pytorch数据读取:DataLoader与Dataset_第5张图片

 

 

二、pytorch数据读取机制

下面从人民币分类的任务中学习pytorch读取数据的机制。

我们要训练一个模型,能够对第四套人民币中的1元和100元进行分类。

 

我们要解决下面三个问题。从哪读数据(怎么传入保存数据的地方)?怎么读?

 

 

pytorch数据读取:DataLoader与Dataset_第6张图片

 

首先在for循环中使用了DataLoader,经过DataLoader之后会根据是否使用多进程,进入单进程或多进程的DataLoaderIter。进入之后,会使用Sampler,获取Index。拿到索引之后,会给DatasetFetcher,DatasetFetcher会调用Dataset,Dataset根据给的索引在getitem从硬盘中读取图像和标签。读取了一个batchsize的数据之后,经过collate_fn整理成一个batch Data的形式。然后就可以输入到模型中训练了。

 

 

 

 

 

 

三、通过人民币的例子来熟悉数据读取的过程


 

 

1. 数据收集

 

假设我们已经收集好数据并已经打好标签。

 

|---data    //用来存放数据
|     |
|     |---RMB_data              //原始数据
               |
               |---1             //一元的人民币图像100张
               |---100         //100元人民币图像100张

 

如下图所示:

pytorch数据读取:DataLoader与Dataset_第7张图片

 

pytorch数据读取:DataLoader与Dataset_第8张图片

 

 

2、数据划分

 

编写data_split.py,划分数据集

# -*- coding: utf-8 -*-
"""
# @file name  : data_split.py
# @brief      : 将数据集划分为训练集,验证集,测试集
"""

import os
import random
import shutil

BASE_DIR = os.path.dirname(os.path.abspath(__file__)) #os.path.abspath(__file__):获取当前脚本的完整路径;os.path.dirname():去掉文件名,返回目录


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    dataset_dir = os.path.abspath(os.path.join(BASE_DIR, "data", "RMB_data"))
    split_dir = os.path.abspath(os.path.join(BASE_DIR, "data", "rmb_split"))
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    if not os.path.exists(dataset_dir):
        raise Exception("\n{} 不存在,请下载 02-01-数据-RMB_data.rar 放到\n{} 下,并解压即可".format(
            dataset_dir, os.path.dirname(dataset_dir)))

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))
            print("已在 {} 创建划分好的数据\n".format(out_dir))

 

执行完后,会在data下生成rmb_split文件夹,下面有对应划分好的数据。

 

|---data    //用来存放数据
      |

      |---rmb_split             //执行了data_split.py之后的生成的文件夹,用来存放划分后的数据集。
                |
                |---train
                |      |
                |      |---1           //80张
                |      |---100       //80张
                |
                |---valit
                |      |
                |      |---1          //10张
                |      |---100      //10张
                |
                |---valid
                       |
                       |---1          //10张
                       |---100      //10张

 

 

 

 

 

 

 

 

pytorch数据读取:DataLoader与Dataset_第9张图片

 

 

你可能感兴趣的:(11,Python/DL/ML)