【Pytorch】15. 图片数据的处理

之前我们用到的数据集都是人工合成的数据集,现在让我们用一些真实生活中的数据集,本节会用到一些 full-sized images 就像我们用手机拍出来的照片一样。

我们使用Kaggle上的dataset of cat and dog photos

【Pytorch】15. 图片数据的处理_第1张图片

文章目录

    • Loading Image Data
      • Transforms
      • Data Loaders
    • Data Augmentation

Loading Image Data

最简单的导入图片的方法就是使用 torchvision 中的 datasets.ImageFolder

dataset = datasets.ImageFolder('path/to/data', transform=transform)

其中 'path/to/data' 是文件的地址。 transform 是一系列的对图片的操作 transforms

图片地址储存的结构就像这样

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

每个分类都有单独的文件夹 (cat and dog) ,每张图片的label就是它们的文件夹名。. 例如这里 123.png 的class label 就是cat

这个数据集已经分好了类,并划分成训练和测试集,可以从这里下载from here.

Transforms

当数据load到imagefolder之后,现在就需要定义一些transform。因为图片的大小不一样,为了训练的时候大小一致,我们可以要么resize它们 transforms.Resize(), 要么就是crop它们 transforms.CenterCrop()transforms.RandomResizedCrop()。我们还需要再把这些图片转化成Tensor格式方便用于Pytorch,transforms.ToTensor()。 通常会把这些操作都合在一起,用 transforms.Compose()

transform = transforms.Compose([transforms.Resize(255),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor()])

还有很多其他的transform方式,可以看documentation.

Data Loaders

数据load到imagefolder之后,还需要将它们传到 DataLoaderDataLoader 的作用接受数据,然后返回batch形式的数据和对用的labels, 同时dataloder中也可以设置多种参数,例如batch的大小,是否要打乱shuffle

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

dataloder实际上是一个generator,也就是生成器,为了得到里面的数据,我们需要用到next()

# Looping through it, get a batch on each loop 
for images, labels in dataloader:
    pass

# Get one batch
images, labels = next(iter(dataloader))

Data Augmentation

训练神经网络的一个惯用策略就是对图片进行随机化,例如随机旋转,镜像,缩放,剪切。

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5], 
                                                            [0.5, 0.5, 0.5])])

查看所有的transform方式the available transforms here

然后图片就会成这样:

【Pytorch】15. 图片数据的处理_第2张图片

完整的代码:

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms



def imshow(image, ax=None, title=None, normalize=True):
    """Imshow for Tensor."""
    if ax is None:
        fig, ax = plt.subplots()
    image = image.numpy().transpose((1, 2, 0))

    if normalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = std * image + mean
        image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')

    return ax

data_dir = 'Cat_Dog_data'

# TODO: Define transforms for the training data and testing data
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor()]) 

test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor()])


# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=32)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)


# change this to the trainloader or testloader 
data_iter = iter(testloader)

images, labels = next(data_iter)
fig, axes = plt.subplots(figsize=(10,4), ncols=4)
for ii in range(4):
    ax = axes[ii]
    helper.imshow(images[ii], ax=ax, normalize=False)
    
    

查看完整代码参考
https://github.com/udacity/deep-learning-v2-pytorch.git中
intro-to-pytorch的Part 5

本系列笔记来自Udacity课程《Intro to Deep Learning with Pytorch》

全部笔记请关注微信公众号【阿肉爱学习】,在菜单栏点击“利其器”,并选择“pytorch”查看

【Pytorch】15. 图片数据的处理_第3张图片

你可能感兴趣的:(pytorch)