Pytorch——DataSet与DataLoader

在使用 pytorch 构建深度学习相关的项目时,通常需要经过【模型结构】-【损失函数定义】-【数据设置】-【训练代码】-【log、验证、可视化与 checkpoints】。其中,【数据设置】往往因为项目/任务的不同,需要自定义合适的DataLoader(数据加载器)。

本文即将介绍 torch.utils.data 中的 Dataset 与 Dataloader 的基本用法,以 Unpaired Image-to-Image Translation 任务的非成对图像数据的加载为例,讲解 pytorch 如何自定义数据加载器。

下面的代码均在文件 dataset.py 中。

(一)引入必须的包

# -*- coding:utf-8 -*-

import torch.utils.data as data
import torchvision.transforms as transforms
import os
from PIL import Image
import random
import torch
import numpy as np

(二)自定义数据集 Dataset

#### 01. Create a dataset
## BaseDataset
class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return "BaseDataset"

    def initialize(self, opt):
        self.opt = opt
'''
定义一些公用的属性/函数;一般的,torch.utils.data.Dataset 本身已经包含了很多属性,如 __len__, __getitem__ 等。

一般我们会新增一个成员函数 name 和 initialize,分别用于:
1)name:没有任何意义,纯属装 B
2)在 pytorch 中,我们经常会使用到 parser,即一个能够从命令行赋予超参数值的辅助类,我们在代码中实例化它的一个对象为 "opt" ,而且,诸如 opt.img_size, opt.batch_size 这样的参数是与 data 相关的,所以我们通常会在这个函数引入 opt,并将它作为自己一个属性 self.opt,如此,我们就可以随时访问所有的超参数了。
'''

下面我们要自定义数据集 UnAlignedDataset。

首先看看我们的数据集长什么样:

Pytorch——DataSet与DataLoader_第1张图片,这是典型的 UIT 模型的数据集结构,可以知道涉及到 Dual training。每个子文件夹下都是一系列图像,且是不对齐的。

我们解释一下一些 opt 的参数:


opt.dataroot = '__data__/horse2zebra'
opt.mode = 'train'            # 训练的时候是 train,测试的时候是 test,用来辅助分情况

opt.trainA = 'trainA'
opt.trainB = 'trainB'
opt.testA  = 'testA'
opt.testB  = 'testB'

opt.load_size = 288           # 读入图像大小
opt.crop_size = 256           # 将读入后的图像随机裁剪出的 patch 的大小
opt.input_nc  = 3             # 图像输入的通道数:RGB-3,灰度图-1,CMYK-4等等,一般是前两种情况

下面我们的思路是:(1)在initialize中获取所有图像的路径以确保我们可以访问它们;(2)在initialize定义图像数据的基本处理流水线;(3)在__getitem__中定义返回怎么样的数据。

 

## SelfDataset
class UnAlignedDataset(BaseDataset):
    ## 重写 name,返回数据集的名字,一般用不到
    def name(self):
        return "UnAlignedDataset"

    ## 重写 initialize
    '''
    这里我们会根据传入的 opt,获取数据集的基本信息
    '''
    def initialize(self, opt):
        self.opt = opt                                     #-> 获取 opt

        ## get dir 
        self.dataroot = opt.dataroot                       #-> 根据 opt 里的 dataroot 得知数据集的位置

        ## get images                                      #-> 构建图像子文件夹的路径
        if opt.mode == 'train':
            dir_A = os.path.join(opt.dataroot, opt.trainA)
            dir_B = os.path.join(opt.dataroot, opt.trainB)
        elif opt.mode == 'test':
            dir_A = os.path.join(opt.dataroot, opt.testA)
            dir_B = os.path.join(opt.dataroot, opt.testB)

        A_paths = os.listdir(dir_A)
        B_paths = os.listdir(dir_B)
        self.length = min(len(A_paths), len(B_paths))      #-> 获取图像域 A 和图像域 B 的所有文件的文件名;并定义数据集大小为两个域的大小的较小一个,构建新的属性 self.length 存储它

        ## get full path
        for i in range(len(A_paths)):
            A_paths[i] = os.path.join(dir_A, A_paths[i])
        for i in range(len(B_paths)):
            B_paths[i] = os.path.join(dir_B, B_paths[i])   #-> 为了方便调用,先构建每张图像的完整路径(这里用相对路径)
        self.A_paths = A_paths 
        self.B_paths = B_paths                             #-> 最后,为了在其他成员函数中可以直接访问,我们构建新的属性来存储它们

        self.input_nc = self.opt.input_nc                  #-> 当然,对于一些重要的属性,我们可以从 opt. 中单独取出,下次用的时候就不需要经过 self.opt.xxx 调用,当然你也可以这么做,只不过不优雅

        ## define transform
        transforms_list = [transforms.ToTensor(),                  #-> 从numpy到torch.tensor
                           transforms.Normalize((0.5, 0.5, 0.5),   
                                                (0.5, 0.5, 0.5))]  #-> 归一化到 -1.0~+1.0
        self.transform = transforms.Compose(transforms_list)
        #-> 定义数据处理的过程,注意,经过 torch.utils.data.Dataset 读入的图像就已经将像素值转换为浮点数,范围在 0~1.0 之间了,类型是 numpy 数组

    ## Dataset 类的核函数,用 len(dataset_object) 调用,返回数据集的大小
    #-> Dataset 的大小与 DataLoader 的 batch_size 共同决定了一个 epoch 中 迭代次数的多少。即:length_of_dataset // batch_size
    def __len__(self):
        return self.length

    ## 这个核函数是 dataset 被调用时自己内部调用的,每次 dataset 用 next 获取下一个 batch 的数据的时候,内部会用连续的 batch_size 个索引来取值,并将最后的 batch_size 个结果在第〇个维度拼接在一起。
    '''
    举个栗子,在图像中,网络的输入一般是:(B, C, H, W);在视频中,输入一般是:(B, C, T, H, W)
    而在 __getitem__ 中,我们通过定义它,让数据返回的数据是:(C, H, W)或者(C, T, H, W)的形式
    '''
    def __getitem__(self, index):
        #-> 首先我们获取图像路径,注意由于我们的任务需要两个图像域的图像
        #-> 我们根据索引对应数据大小的模来定位
        A_pth = self.A_paths[index % self.length]
        B_pth = self.B_paths[index % self.length]    

        #-> 读入图像
        x_img = Image.open(A_pth).convert('RGB')                                          #-> 读入图像
        x_img = x_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)     #-> 双线性插值放缩到我们指定的大小(256x256)
        x = self.transform(x_img)                                                         #-> 数据预处理  

        y_img = Image.open(B_pth).convert('RGB')
        y_img = y_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
        y = self.transform(y_img)

        ## random crop 随机裁剪
        h, w = x.size(1), x.size(2)
        h_offset = random.randint(0, max(0, h - self.opt.crop_size - 1))
        w_offset = random.randint(0, max(0, w - self.opt.crop_size - 1))
        x = x[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
        y = y[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]

        ## expand to 4-dim tensor
        if self.opt.input_nc == 1:
            # RGB to gray
            tmp_x = x[0, ...] * 0.299 + x[1, ...] * 0.587 + x[2, ...] * 0.114
            x = tmp_x.unsqueeze(0)  # (H,W) -> (C=1,H,W)
            tmp_y = y[0, ...] * 0.299 + y[1, ...] * 0.587 + y[2, ...] * 0.114
            x = tmp_y.unsqueeze(0)  # (H,W) -> (C=1,H,W)

        return {'A': x, 'B':y, 'A_pth': A_pth, 'B_pth': B_pth} 
        '''
        返回什么样的数据是我们自定义的,后面我们会看到,我们怎么使用它:

        for i, data in enumerate(dataset):
            real_x = data['A']
            real_y = data['B']
            ...
        
        可以发现,DataLoader 只负责返回 batch 的数据(数据分不同部分时,各个部分单独作为 batch),数据的具体内容自定义的

        '''

好了,我们可以发现,DataSet定义的是如何对要返回的单个数据做处理(像素值归一化、图像裁剪、颜色空间等,即所有一切我们在“数字图像处理”上学到的图像处理的技术都可以应用);我们发现,transform中有些是可以直接使用的;如果没有,可以自定义transform的处理函数,也可以像上面RGB转Gray那样直接写在__getitem__中

(三)自定义数据加载器

前面说,DataSet定义的是返回的是单个数据,那么形成batch的任务、快速加载(分线程)的任务、每个epoch后shuffle(洗牌)数据集的任务等等,都是由DataLoader来完成的。

首先,我们定义一个基本的DataLoader,主要也是为了引入 opt 所以新增成员函数 initialize 。

#### 0.2 Create a Dataloader
## BaseDataLoader
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None

下面我们新定义 UnAlignedDataLoader 的数据加载器。

## Dataloader for self data
class UnAlignedDataLoader(BaseDataLoader):
    def name(self):
        return "UnAlignedDataLoader"

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)  # get the copy of opt->self.opt

        # add dataset and nitialize it
        self.dataset = UnAlignedDataset()
        self.dataset.initialize(opt)          # 因为 initialize 不是 torch.utils.data.Dataset 的核函数,所以我们需要手动调用它,才算完整初始化

        # define a data loader
        self.dataloader = data.DataLoader(    # 调用 torch.utils.data.DataLoader,
            self.dataset,
            batch_size=opt.batch_size,        # batch 的大小
            shuffle=True,                     # 每个 epoch 后是否洗牌
            num_workers=int(opt.n_threads)    # 使用多少个进程加载数据
        )

    def load_data(self):                      # 返回整个数据加载器本身!!!非常重要
        return self

    def __len__(self):                        # 返回数据集的大小
        return len(self.dataset)

    def __iter__(self):
        for _, d in enumerate(self.dataloader):
            yield d                           # 核函数,用于每次以 batch 遍历整个数据集,即一个epoch

现在我们可以发现了,其实,许多都是套路!我们需要自定义的最主要的就是 UnAlignedDataset 中,在 initialize 中获取所有数据的路径;在 __getitem__ 中读入数据,并作自定义的处理(放缩、裁剪、像素值归一化等等),这些处理可以是transform中已有的,也可以是自定义的。

此外的其他三个类,结构与内容都基本不需要怎么改。

(四)测试

最后就是测试啦~

#### Test data loader
from config import parser
opt = parser.parse_args() ##-> 这是我自定义的,大家需要自己定义,结构大致如下:
'''
# config.py

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
...

'''

data_loader = UnAlignedDataLoader()
data_loader.initialize(opt)

data_set = data_loader.load_data()

for i, data in enumerate(data_set):
    print(i, data['A'].size(), data['B'].size())

Pytorch——DataSet与DataLoader_第2张图片输出如左图所示。 

至此,pytorch 自定义简单的数据加载器遍历数据集的做法介绍到此结束,如有疏漏/错误,敬请指出!

你可能感兴趣的:(深度学习框架)