pytorch基础语法学习:数据读取机制Dataloader与Dataset

来源:投稿 作者:阿克西
编辑:学姐

本章主要讲述数据模块,如何从硬盘中读取数据,对数据进行预处理、数据增强,转换为张量的形式输入到模型之中。

1 模块简介

本节主要学习数据模块当中的数据读取,数据模块通常还会分为四个子模块,数据收集、数据划分、数据读取、数据预处理。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第1张图片

● 数据收集:收集原始样本和标签,如Img和Label。

● 数据划分:划分成训练集train,用来训练模型;验证集valid,验证模型是否过拟合,挑选还没有过拟合的时候的模型;测试集test,测试挑选出来的模型的性能。

● 数据读取:PyTorch中数据读取的核心是Dataloader。Dataloader分为Sampler和DataSet两个子模块。Sampler的功能是生成索引,即样本序号;DataSet的功能是根据索引读取样本和对应的标签。

● 数据预处理:数据的中心和,标准化,旋转,翻转等,在PyTorch中是通过transforms实现的。

这里主要学习第三个子模块中的Dataloader和Dataset。

2 DataLoader与Dataset

DataLoader和Dataset是pytorch中数据读取的核心。

2.1 DataLoader

torch.utils.data.DataLoader
DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器,每一次for循环就是从DataLoader中加载一个batchsize数据。

● dataset:Dataset类,决定数据从哪读取及如何读取

● batchsize:批大小

● num_works:是否多进程读取数据

● shuffle:每个epoch是否乱序

● drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

概念辨析:

● epoch:所有训练样本都已输入到模型中,称为一个epoch

● iteration:一批样本输入到模型中,称之为一个iteration

● batchsize:批大小,决定一个epoch中有多少个iteration

样本总数:80,batchsize:8 (样本能被batchsize整除)

● 1(epoch) = 10(iteration)

样本总数:87,batchsize=8 (样本不能被batchsize整除)

● drop_last = True:1(epoch) = 10(iteration)

● drop_last = False:1(epoch)= 11(iteration)

2.2 Dataset

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且必须复写__getitem__()。

● Dataset:用来定义数据从哪里读取,以及如何读取的问题

● getitem:接收一个索引,返回一个样本

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __add__(self, other)
        return ConcatDataset([self,other])

3 人民币二分类

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第2张图片

要求:对第四套人民币1元和10元进行二分类,将人民币看作自变量x,类别看作因变量y,模型就是将自变量x映射到因变量y。

下面对人民币二分类的数据进行读取,从三个方面了解pytorch的读取机制,分别为读哪些数据、从哪读数据、怎么读数据。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第3张图片

「读哪些数据:」在每一个iteration的时候应该读取哪些数据,每一个iteration读取一个batch大小的数据,假如有80个样本,那么从80个样本中读取8个样本,那么应该读取哪8个样本。

「从哪读数据 :」在硬盘当中,我们应该怎么找到对应的数据,在哪里设置参数。

「怎么读数据 :」从代码中学习。

3.1 数据集划分

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第4张图片

import os
import random
import shutil
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

if __name__ == '__main__':
    dataset_dir = os.path.abspath(os.path.join(BASE_DIR, "RMB_data"))
    split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    if not os.path.exists(dataset_dir):
        raise Exception("\n{} 不存在,请下载 02-01-数据-RMB_data.rar 放到\n{} 下,并解压即可".format(
            dataset_dir, os.path.dirname(dataset_dir)))

    # 训练集、验证集与测试集所占比例
    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            if img_count == 0:
                print("{}目录下,无图片,请检查".format(os.path.join(root, sub_dir)))
                import sys
                sys.exit(0)
            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))
            print("已在 {} 创建划分好的数据\n".format(out_dir))

划分好的数据集:

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第5张图片

3.2 人民币分类模型训练

3.2.1 导入包与参数设置

import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# 指定GPU进行训练
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import random
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt

path_lenet = os.path.abspath(os.path.join(BASE_DIR, "model", "lenet.py"))
path_tools = os.path.abspath(os.path.join(BASE_DIR, "tools", "common_tools.py"))
assert os.path.exists(path_lenet), "{}不存在,请将lenet.py文件放到 {}".format(path_lenet, os.path.dirname(path_lenet))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))

import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from model.lenet import LeNet
from tools.my_dataset import RMBDataset   # RMBDataset类
# from tools.common_tools import set_seed
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

rmb_label = {"1": 0, "100": 1}

3.2.2 Dataset类参数一:从哪读数据,设置硬盘中的路径

# ============================ step 1/5 数据 ============================
# 1、数据读取的三个问题中:从哪读数据,这里设置了硬盘中的路径
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
    raise Exception(r"数据 {} 不存在, "
                    r"回到lesson-06\1_split_dataset.py生成数据"
                    .format(split_dir))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

3.2.3 Dataset类参数二:数据预处理transform

# 数据预处理

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),           # 对数据进行缩放
    transforms.RandomCrop(32, padding=4),  # 对数据进行裁剪
    # 对数据进行转换,把图像转换成张量数据,并进行归一化,从0~255 → 0~1
    transforms.ToTensor(),
    # 数据标准化,将均值变为0,标准差变为1
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

transforms.Compose将一系列数据增强方法进行有序的组合,依次按照顺序对图像进行处理。

3.2.4 构建Dataset实例

Dataset必须是用户自己构建的,在Dataset中会传入两个主要参数,一个是data_dir,表示数据集的路径,即从哪读数据;第二个参数是transform,表示数据预处理。代码中构建了两个Dataset实例,一个用于训练,一个用于验证。

# 构建MyDataset实例,MyDataset必须是用户自己构建的
# 输入训练集所在路径data_dir,transform是数据预处理
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

3.2.5 构建DataLoader实例

有了Dataset就可以构建数据迭代器DataLoader,DataLoader传入的第一个参数是Dataset,也就是RMBDataset实例;第二个参数是batch_size;在训练集中的多了一个参数shuffle=True,作用是每一个epoch中样本都是乱序的。

# 构建DataLoder,shuffle=True,每一个epoch中样本都是乱序的
train_loader = DataLoader(dataset=train_data,
        batch_size=BATCH_SIZE,
        shuffle=True)

valid_loader = DataLoader(dataset=valid_data,
        batch_size=BATCH_SIZE)

3.2.6 模型、损失函数、优化器

# ============================ step 2/5 模型 ============================
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ========================
criterion = nn.CrossEntropyLoss() # 选择损失函数

# ============================ step 4/5 优化器 ==========================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
# 设置学习率下降策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

3.2.7 开始训练

设置好数据、模型、损失函数和优化器之后,就可以进行模型的训练。

模型训练以epoch为周期,代码中先进行epoch的主循环,在每一个epoch当中会有多个iteration的训练,在每一个iteration当中去训练模型,每一次读取一个batch_size大小的数据,然后输入到模型中,进行前向传播,反向传播获取梯度,更新权值,接着统计分类准确率,打印训练信息。在每一个epoch会进行验证集的测试,通过验证集来观察模型是否过拟合。

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}"
                .format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}"
                .format(epoch, MAX_EPOCH, j+1, len(valid_loader), 
                        loss_val_epoch, correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval - 1  

valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

输出结果

Training:Epoch[000/010] Iteration[010/010] Loss: 0.6472 Acc:58.75%
Valid:  Epoch[000/010] Iteration[002/002] Loss: 0.3726 Acc:50.00%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.2734 Acc:88.12%
Valid:  Epoch[001/010] Iteration[002/002] Loss: 0.0280 Acc:100.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.0395 Acc:99.38%
Valid:  Epoch[002/010] Iteration[002/002] Loss: 0.0006 Acc:100.00%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.4306 Acc:94.38%
Valid:  Epoch[003/010] Iteration[002/002] Loss: 0.2107 Acc:90.00%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.1142 Acc:98.12%
Valid:  Epoch[004/010] Iteration[002/002] Loss: 0.0531 Acc:100.00%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0443 Acc:98.12%
Valid:  Epoch[005/010] Iteration[002/002] Loss: 0.0003 Acc:100.00%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0070 Acc:100.00%
Valid:  Epoch[006/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0036 Acc:100.00%
Valid:  Epoch[007/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0001 Acc:100.00%
Valid:  Epoch[008/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0003 Acc:100.00%
Valid:  Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第6张图片

3.3 RMBDataset类

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)
        # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    # 根据index索引返回数据集图片和对应标签
    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        # 0~255
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            # 在这里做transform,转为tensor等等
            img = self.transform(img)

        return img, label

    # 查看样本数量
    def __len__(self):
        return len(self.data_info)

    # 自定义函数,获取数据的路径以及标签
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))
        # 有了data_info,就可以返回上面的__getitem__()函数中的self.data_info[index],
        # 根据index索取图片和标签
        return data_info

3.4 断点调试

现在了解一下上面代码中RMBDataset中的具体实现。

pycharm小技巧:按住Ctrl,然后单击函数名或者类名就可以跳转到具体函数实现的位置。

在训练模型时,数据的获取是通过for循环获取的,从DataLoader迭代器中不停地去获取一个batchsize大小的数据。

1、下面通过代码的调试观察pytorch是如何读取数据的,在该处设置断点,然后执行debug。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第7张图片

点击step into功能键,跳转到对应的函数中,发现是跳到了dataloader.py文件中的__iter__()函数;具体如下所示:

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第8张图片

这段代码是一个if的判断语句,其功能是判断是否采用多进程;如果采用多进程,有多进程的读取机制;如果是单进程,有单进程的读取机制;这里以单进程进行演示。

2、单击两次step into功能键

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第9张图片

单进程当中,最主要的是__next__()函数,在next中会获取index和data,回想一下数据读取中的三个问题,第一个问题是读哪些数据;__next__()函数就告诉我们,在每一个iteration当中读取哪些数据。

现在将光标对准_next_data函数中的第一行index=self._next_index(),点击功能区中的run to cursor,然后程序就会运行到这一行,点击功能区中的step into,进入到_next_index()函数中了解是怎么获得数据的index的;之后代码会跳到下面的代码中:

再点击一下step into就进入了sampler.py文件中,sampler是一个采样器,其功能是告诉我们每一个batch_size应该读取哪些数据;

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第10张图片

点击两次step out功能键

点击step over功能键,执行上面这段代码中的:

index = self._next_index()  # may raise StopIteration

就可以挑选出一个Iteration中的index,batch_size的值是16,则index列表长度为16:

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第11张图片

有了index之后,将index输入到Dataset当中去获取data,代码中会进入dataset_fetcher.fetch()函数。

3、点击功能区中的step_into,进入到fetch.py文件的_MapDatasetFetcher()类当中,在这个类里面实现了具体的数据读取,具体代码如下。代码中调用了dataset,通过输入一个索引idx返回一个data,将一系列的data拼接成一个list。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第12张图片

点击step into查看一下这个过程,代码跳转到自定义dataset类RMBdataset()中的__getitem__()函数中,所以dataset最重要最核心的就是__getitem__()函数;

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第13张图片

这里已经实现了data_info()函数,对数据进行初步的读取,可以得到图片的路径和标签;然后通过Image.open来读取数据,这就实现了一个数据的读取,标签的获取。

之后点击step_out跳出该函数,会返回fetch()函数中;

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第14张图片

在fetch()函数return的时候会进入一个collate_fn(),它是数据的整理器,会将我们读取到的16个数据整理出一个batch的形式;得到数据和标签。

将光标放在return self.collate_fn(data) 处,点击run to cursor执行到当前位置,之后点击step over返回到单进程,点击step over,执行到下述代码,发现data已被打包,第一个元素是图像,第二个元素是标签。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第15张图片

点击多次step out返回到最初训练模型读取数据的位置,执行step over可以发现循环中的data已被打包,第一个元素是图像,第二个元素是标签。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第16张图片

3.5 总结

通过以上的分析,可以回答一开始提出的数据读取的三个问题:

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第17张图片

「1、读哪些数据?」

答:从代码中可以发现,index是从sampler.py中输出的,所以读哪些数据是由sampler得到的;

「2、从哪读数据?」

答:从代码中看,是从Dataset中的参数data_dir告诉我们pytorch是从硬盘中的哪一个文件夹获取数据。

「3、怎么读数据?」

答:从代码中可以发现,pytorch是从Dataset的getitem()中具体实现的,根据索引去读取数据。

「Dataloader读取数据很复杂,需要经过四五个函数的跳转才能最终读取数据」

为了简单,将整个跳转过程以流程图进行表示,通过流程图对数据读取机制有一个简单的认识。

简单描述一下流程图:

  1. 首先在for循环中去使用DataLoader;

  2. 进入DataLoader之后是否采用多进程进入单进程或者多进程的DataLoaderlter;

  3. 进入DataLoaderIter之后会使用sampler去获取Index;

  4. 拿到索引之后传输到DatasetFetcher;

  5. 在DatasetFetcher中会调用Dataset,Dataset根据给定的Index,在getitem中从硬盘里面去读取实际的Img和Label;

  6. 读取了一个batch_size的数据之后,通过一个collate_fn将数据进行整理;

  7. 整理成batch_Data的形式,接着就可以输入到模型中训练。

pytorch基础语法学习:数据读取机制Dataloader与Dataset_第18张图片

读哪些是由Sampler决定的index,从哪读是由Dataset决定的,怎么读是由getitem决定的。

关注下方《学姐带你玩AI》

回复“500”轻松获取AI必读200篇高分论文

码字不易,欢迎大家点赞评论收藏!

你可能感兴趣的:(深度学习干货,粉丝的投稿,人工智能干货,深度学习,人工智能,pytorch)