使用pytorch准备自己的数据

前言

对于著名的数据集比如mnist,像Tensorflow、pytorch这样的流行框架已把它们集成到相关模块中,使用时一至几行简单的代码就能调用。但往往我们需要在自己的数据集上完成一些操作,这篇博客就旨在以单标签图像分类为例,浅谈一下如何使用pytorch准备自己的数据,如有错误,敬请斧正。


我所做的是一个室外图像的天气分类任务,类别只有sunny和cloudy两类。在这个例子中我们不需要提供额外的txt或其他形式的文件来将图片和标签对应起来,但需要将数据集按以下结构组织起来。

使用pytorch准备自己的数据_第1张图片

训练集和验证集(当然还可以有测试集)需要分开,每个split下面各个类别的图片也要分开,并且文件夹的名字最好就是类别名称。(关于使用pytorch进行回归或多标签分类任务本人还未研究过,这里就暂时不作介绍了)


下面就开始讲代码了,首先把全部代码贴出来,然后再细致解释一下。

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt 
import numpy as np
import os

# Data augmentation and normalization for training 
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = '/mount/temp/WZG/pytorch/Data/'

train_sets = datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train'])
train_loader = torch.utils.data.DataLoader(train_sets, batch_size=10, shuffle=True, num_workers=4)
train_size = len(train_sets)
train_classes = train_sets.classes

val_sets = datasets.ImageFolder(os.path.join(data_dir, 'val'), data_transforms['val'])
val_loader = torch.utils.data.DataLoader(val_sets, batch_size=10, shuffle=False, num_workers=4)
val_size = len(val_sets)


# Visualize a few images
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    plt.imshow(inp)
    if title is not None:
        plt.title(title)


inputs, classes = next(iter(train_loader))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs, nrow=5)

imshow(out, title=[train_classes[x] for x in classes])

使用pytorch准备自己的数据时,主要用到的就是torch.utils.data模块以及torchvision的datasets和transforms模块,所以代码开始部分,首先把相关模块导入。

在代码的正文部分,首先定义对图像的一些变换以达到数据增强的目的,函数基本都是见名知意的,如果读者了解深度学习,那么对这些变换一定不会陌生。除了数据增强,还有一步操作是必要的,那就是把图像转换成pytorch需要的tensor格式。pytorch里面的tensor和Numpy的ndarray很像(但绝不等价),pytorch的官网在做介绍时,很多时候会和Numpy进行联系和对比,而tensor和ndarray也可以通过调用相关函数进行相互转换。数据转换成tensor后,数值范围会被自动压缩到0~1之间。这份代码中之后还使用transforms.Normalize()函数对数据进行了归一化,该函数包含两个list类型的参数,第一个参数为RGB三个通道各自的均值,第二个参数为相应的方差。另外,代码中的transforms.Compose()函数的作用是把所有这些变换组合到一起。这个例子中,我们把对训练集和验证集的变换操作写到了一个字典里,不过也完全可以将它们分开来写。

接下来,我们使用datasets.ImageFolder()函数来创建dataset对象,该函数的第一个参数是一个路径(比如这个例子中的训练集的路径),第二个参数是对这个路径下的图片要进行的变换操作(我们刚刚定义的那些变换)。再然后就是使用torch.utils.data.DataLoader()函数来定义数据的加载方式了,例子中对该函数给了4个参数,第一个是刚刚创建的dataset对象,第二个是batch的大小(即一个batch包含的样本数量),第三个参数是一个布尔值,代表是否进行shuffle,训练的话一般都会设为True。第四个参数num_workers表示开启多少个子进程进行数据的读取(并行读取),默认是0,即只使用主进程读数据。其余的更多参数请查阅官网doc。

train_size = len(train_sets)
train_classes = train_sets.classes

接下来的这两句是为了得到这个数据集的大小和所有的类别名称。得到的结果如下图所示:
使用pytorch准备自己的数据_第2张图片

可以看到我们创建的dataset对象的classes属性就对应着相关类别文件夹的名称。


为了测试数据是否能正确加载,我们定义一个imshow()函数来展示数据。imshow()函数中,首先对数据的维度进行transpose操作,这是因为tensor中,图片的shape是先通道再宽/高,如下:

inputs_shape

而显示图像需要先宽/高再通道。到这里读者应该有个疑问,那就是为什么batch的维度是4,而transpose函数里的shape却是3维的。这其实跟我们的显示形式有关,我后面会马上讲到。transpose之后就是归一化的逆操作了,最后使用plt.imshow()函数进行显示即可。


实际加载数据时,使用inputs, classes = next(iter(train_loader))一行代码就可以得到一个batch的数据,该函数会进行非重复采样,直至数据集被完整遍历一次。得到的inputs即一个batch的图像数据,而这里的classes是inputs各个样本对应的整数标签,自动从0开始,并且和类别名称的索引也是对应的。在这个例子中,0就对应cloudy,1对应sunny。再次把之前的结果贴一下,就更好理解了。
使用pytorch准备自己的数据_第3张图片

为了展示方便,我们使用下面这行代码对数据进行了一下处理:

out = torchvision.utils.make_grid(inputs, nrow=5)

这里的make_grid()函数就是把一个batch的数据重新排列成格的形式,例子中它的第一个参数即我们刚刚加载的batch,第二个参数代表一行放几个样本。处理之后,数据就变成了3维的,如下图所示:

outputs_shape

这就是为什么自定义的imshow()函数中的参数也是3维的。如果细心一点,会发现原数据和处理之后的数据有对不上的地方。原数据的shape为(10,3,224,224),batch size为10,处理时,我设置的是一行显示5张图片,也就是总共2行5列。那么处理后数据应该是(3,448,1120)才对,可从结果来看,处理后数据的高和宽都变大了。这主要是make_grid()函数本身搞的鬼,可能是为了显示时把各幅图片区分开,该函数会在图片之间以及整个grid的边缘自动加上线宽为2个像素的黑线,这是我查看了数据具体数值后发现的,图片之间的部分会有宽度为2,值全为0的间隔,那么从纵向来看,2*3(2行图片,中间有一条间隔,加grid上下两条边)=6=454-448。横向按此方法计算也对得上。

最后显示出的结果如下图所示:

使用pytorch准备自己的数据_第4张图片

从结果来看,我们加载的数据应该是正确的,并且从title可以看出,数据也确实经过了shuffle。

你可能感兴趣的:(pytorch)