torchvision猫狗识别数据处理

在构建深度学习的过程中我们话大量的时间处理数据。一般我们会用到opencv,PIL,skimage等图像处理库。我们今天主要介绍torchvision图像处理库。

torch 中的数据

一般来说torch中的数据必须封装成torch.utils.data.Dataset

from torch.utils.data import Dataset
from os.path import join, exists, basename
from glob import glob
from torchvision import transforms as T

class DogCat(Dataset):
    def __init__(self, folder, transform = T.ToTensor()):
        self.images = glob(join(folder, '*.jpg'))
        self.transform = transform
    def __getitem__(self, item):
        img = Image.open(self.images[item])
        data = self.transform(img)
        target = 1 if basename(self.images[item])).startswith('dog') else 0
        return data, target
    def __len__(self):
        return len(self.images)

其中,transform参数用来传递图像处理功能。一般我们用torchvision.transforms来处理图像T.ToTensor()表示把PIL格式的图像处理成pytorch是识别的数据。

这里需要注意的一点是,ToTensor()会把图像从WxHxC变成CxWxH,同时默认把数据按照一定meanstd进行归一化。

我们还可以对PIL图片做更多的操作,如RandomCropAffine(放射变化)等, 同时还支持把这些操作按照一定的顺序进行组合,如下所示:

transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([.5, .5, .5], [.5, .5, .5])  #这里表示mean=[R, G, B], std=[R, G, B]
])
  • T.Resize(416)(img)表示图片的缩放(或者增大),把图像的最小边缩放为416。
  • T.CenterCrop(416)(img)表示把图片从中剪裁,裁剪后的大小为416x416

你可能感兴趣的:(torchvision猫狗识别数据处理)