在使用 pytorch 构建深度学习相关的项目时,通常需要经过【模型结构】-【损失函数定义】-【数据设置】-【训练代码】-【log、验证、可视化与 checkpoints】。其中,【数据设置】往往因为项目/任务的不同,需要自定义合适的DataLoader(数据加载器)。
本文即将介绍 torch.utils.data 中的 Dataset 与 Dataloader 的基本用法,以 Unpaired Image-to-Image Translation 任务的非成对图像数据的加载为例,讲解 pytorch 如何自定义数据加载器。
下面的代码均在文件 dataset.py 中。
# -*- coding:utf-8 -*-
import torch.utils.data as data
import torchvision.transforms as transforms
import os
from PIL import Image
import random
import torch
import numpy as np
#### 01. Create a dataset
## BaseDataset
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return "BaseDataset"
def initialize(self, opt):
self.opt = opt
'''
定义一些公用的属性/函数;一般的,torch.utils.data.Dataset 本身已经包含了很多属性,如 __len__, __getitem__ 等。
一般我们会新增一个成员函数 name 和 initialize,分别用于:
1)name:没有任何意义,纯属装 B
2)在 pytorch 中,我们经常会使用到 parser,即一个能够从命令行赋予超参数值的辅助类,我们在代码中实例化它的一个对象为 "opt" ,而且,诸如 opt.img_size, opt.batch_size 这样的参数是与 data 相关的,所以我们通常会在这个函数引入 opt,并将它作为自己一个属性 self.opt,如此,我们就可以随时访问所有的超参数了。
'''
下面我们要自定义数据集 UnAlignedDataset。
首先看看我们的数据集长什么样:
,这是典型的 UIT 模型的数据集结构,可以知道涉及到 Dual training。每个子文件夹下都是一系列图像,且是不对齐的。
我们解释一下一些 opt 的参数:
opt.dataroot = '__data__/horse2zebra'
opt.mode = 'train' # 训练的时候是 train,测试的时候是 test,用来辅助分情况
opt.trainA = 'trainA'
opt.trainB = 'trainB'
opt.testA = 'testA'
opt.testB = 'testB'
opt.load_size = 288 # 读入图像大小
opt.crop_size = 256 # 将读入后的图像随机裁剪出的 patch 的大小
opt.input_nc = 3 # 图像输入的通道数:RGB-3,灰度图-1,CMYK-4等等,一般是前两种情况
下面我们的思路是:(1)在initialize中获取所有图像的路径以确保我们可以访问它们;(2)在initialize定义图像数据的基本处理流水线;(3)在__getitem__中定义返回怎么样的数据。
## SelfDataset
class UnAlignedDataset(BaseDataset):
## 重写 name,返回数据集的名字,一般用不到
def name(self):
return "UnAlignedDataset"
## 重写 initialize
'''
这里我们会根据传入的 opt,获取数据集的基本信息
'''
def initialize(self, opt):
self.opt = opt #-> 获取 opt
## get dir
self.dataroot = opt.dataroot #-> 根据 opt 里的 dataroot 得知数据集的位置
## get images #-> 构建图像子文件夹的路径
if opt.mode == 'train':
dir_A = os.path.join(opt.dataroot, opt.trainA)
dir_B = os.path.join(opt.dataroot, opt.trainB)
elif opt.mode == 'test':
dir_A = os.path.join(opt.dataroot, opt.testA)
dir_B = os.path.join(opt.dataroot, opt.testB)
A_paths = os.listdir(dir_A)
B_paths = os.listdir(dir_B)
self.length = min(len(A_paths), len(B_paths)) #-> 获取图像域 A 和图像域 B 的所有文件的文件名;并定义数据集大小为两个域的大小的较小一个,构建新的属性 self.length 存储它
## get full path
for i in range(len(A_paths)):
A_paths[i] = os.path.join(dir_A, A_paths[i])
for i in range(len(B_paths)):
B_paths[i] = os.path.join(dir_B, B_paths[i]) #-> 为了方便调用,先构建每张图像的完整路径(这里用相对路径)
self.A_paths = A_paths
self.B_paths = B_paths #-> 最后,为了在其他成员函数中可以直接访问,我们构建新的属性来存储它们
self.input_nc = self.opt.input_nc #-> 当然,对于一些重要的属性,我们可以从 opt. 中单独取出,下次用的时候就不需要经过 self.opt.xxx 调用,当然你也可以这么做,只不过不优雅
## define transform
transforms_list = [transforms.ToTensor(), #-> 从numpy到torch.tensor
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))] #-> 归一化到 -1.0~+1.0
self.transform = transforms.Compose(transforms_list)
#-> 定义数据处理的过程,注意,经过 torch.utils.data.Dataset 读入的图像就已经将像素值转换为浮点数,范围在 0~1.0 之间了,类型是 numpy 数组
## Dataset 类的核函数,用 len(dataset_object) 调用,返回数据集的大小
#-> Dataset 的大小与 DataLoader 的 batch_size 共同决定了一个 epoch 中 迭代次数的多少。即:length_of_dataset // batch_size
def __len__(self):
return self.length
## 这个核函数是 dataset 被调用时自己内部调用的,每次 dataset 用 next 获取下一个 batch 的数据的时候,内部会用连续的 batch_size 个索引来取值,并将最后的 batch_size 个结果在第〇个维度拼接在一起。
'''
举个栗子,在图像中,网络的输入一般是:(B, C, H, W);在视频中,输入一般是:(B, C, T, H, W)
而在 __getitem__ 中,我们通过定义它,让数据返回的数据是:(C, H, W)或者(C, T, H, W)的形式
'''
def __getitem__(self, index):
#-> 首先我们获取图像路径,注意由于我们的任务需要两个图像域的图像
#-> 我们根据索引对应数据大小的模来定位
A_pth = self.A_paths[index % self.length]
B_pth = self.B_paths[index % self.length]
#-> 读入图像
x_img = Image.open(A_pth).convert('RGB') #-> 读入图像
x_img = x_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC) #-> 双线性插值放缩到我们指定的大小(256x256)
x = self.transform(x_img) #-> 数据预处理
y_img = Image.open(B_pth).convert('RGB')
y_img = y_img.resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
y = self.transform(y_img)
## random crop 随机裁剪
h, w = x.size(1), x.size(2)
h_offset = random.randint(0, max(0, h - self.opt.crop_size - 1))
w_offset = random.randint(0, max(0, w - self.opt.crop_size - 1))
x = x[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
y = y[:, h_offset:h_offset + self.opt.crop_size, w_offset:w_offset + self.opt.crop_size]
## expand to 4-dim tensor
if self.opt.input_nc == 1:
# RGB to gray
tmp_x = x[0, ...] * 0.299 + x[1, ...] * 0.587 + x[2, ...] * 0.114
x = tmp_x.unsqueeze(0) # (H,W) -> (C=1,H,W)
tmp_y = y[0, ...] * 0.299 + y[1, ...] * 0.587 + y[2, ...] * 0.114
x = tmp_y.unsqueeze(0) # (H,W) -> (C=1,H,W)
return {'A': x, 'B':y, 'A_pth': A_pth, 'B_pth': B_pth}
'''
返回什么样的数据是我们自定义的,后面我们会看到,我们怎么使用它:
for i, data in enumerate(dataset):
real_x = data['A']
real_y = data['B']
...
可以发现,DataLoader 只负责返回 batch 的数据(数据分不同部分时,各个部分单独作为 batch),数据的具体内容自定义的
'''
好了,我们可以发现,DataSet定义的是如何对要返回的单个数据做处理(像素值归一化、图像裁剪、颜色空间等,即所有一切我们在“数字图像处理”上学到的图像处理的技术都可以应用);我们发现,transform中有些是可以直接使用的;如果没有,可以自定义transform的处理函数,也可以像上面RGB转Gray那样直接写在__getitem__中!
前面说,DataSet定义的是返回的是单个数据,那么形成batch的任务、快速加载(分线程)的任务、每个epoch后shuffle(洗牌)数据集的任务等等,都是由DataLoader来完成的。
首先,我们定义一个基本的DataLoader,主要也是为了引入 opt 所以新增成员函数 initialize 。
#### 0.2 Create a Dataloader
## BaseDataLoader
class BaseDataLoader():
def __init__(self):
pass
def initialize(self, opt):
self.opt = opt
def load_data(self):
return None
下面我们新定义 UnAlignedDataLoader 的数据加载器。
## Dataloader for self data
class UnAlignedDataLoader(BaseDataLoader):
def name(self):
return "UnAlignedDataLoader"
def initialize(self, opt):
BaseDataLoader.initialize(self, opt) # get the copy of opt->self.opt
# add dataset and nitialize it
self.dataset = UnAlignedDataset()
self.dataset.initialize(opt) # 因为 initialize 不是 torch.utils.data.Dataset 的核函数,所以我们需要手动调用它,才算完整初始化
# define a data loader
self.dataloader = data.DataLoader( # 调用 torch.utils.data.DataLoader,
self.dataset,
batch_size=opt.batch_size, # batch 的大小
shuffle=True, # 每个 epoch 后是否洗牌
num_workers=int(opt.n_threads) # 使用多少个进程加载数据
)
def load_data(self): # 返回整个数据加载器本身!!!非常重要
return self
def __len__(self): # 返回数据集的大小
return len(self.dataset)
def __iter__(self):
for _, d in enumerate(self.dataloader):
yield d # 核函数,用于每次以 batch 遍历整个数据集,即一个epoch
现在我们可以发现了,其实,许多都是套路!我们需要自定义的最主要的就是 UnAlignedDataset 中,在 initialize 中获取所有数据的路径;在 __getitem__ 中读入数据,并作自定义的处理(放缩、裁剪、像素值归一化等等),这些处理可以是transform中已有的,也可以是自定义的。
此外的其他三个类,结构与内容都基本不需要怎么改。
最后就是测试啦~
#### Test data loader
from config import parser
opt = parser.parse_args() ##-> 这是我自定义的,大家需要自己定义,结构大致如下:
'''
# config.py
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
...
'''
data_loader = UnAlignedDataLoader()
data_loader.initialize(opt)
data_set = data_loader.load_data()
for i, data in enumerate(data_set):
print(i, data['A'].size(), data['B'].size())
至此,pytorch 自定义简单的数据加载器遍历数据集的做法介绍到此结束,如有疏漏/错误,敬请指出!