Pytorch基础(三)数据集加载及预处理

python提供了许多工具简化数据加载,使代码更具可读性。经常用到的包有scikit-image、pandas等,本文通过相关包进行数据加载和预处理相关简要介绍。

从此处(提取码:ilqy)下载数据集,数据存于"data/faces/"的目录中。这个数据集实际上是imagenet数据集标注为face的图片当中在dlib面部检测(dlib's pose estimation)表现良好的图片。下面以该数据集为例,对数据加载即预处理进行简要介绍。

下载数据集及显示样本

下面为下载数据集及显示其中某一样本的相关代码:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/3, 23:54
# @Author: Lee
# @File  : test.py

import os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

import warnings
warnings.filterwarnings("ignore")

plt.ion()  # interactive mode

# # 读取数据集  将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量
landmarks_frame = pd.read_csv("data/faces/face_landmarks.csv")

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].values
landmarks = landmarks.astype('float').reshape(-1, 2)


print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))


# 展示一张图片和它对应的标注点
def show_landmarks(image, landmarks):
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)


plt.figure()
show_landmarks(io.imread(os.path.join("data/faces/", img_name)), landmarks)
plt.show()
plt.pause(0)

打印结果如下:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

显示图形如下:

Pytorch基础(三)数据集加载及预处理_第1张图片

数据集类

torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法*__len__实现len(dataset)返回数据集的尺寸。*__getitem__用来获取一些索引数据,例如dataset[i]中的(i)。

建立数据集类及显示部分样本

为面部数据集创建一个数据集类。在__init__中读取csv的文件内容,在__getitem__中读取图片。这么做是为了节省内存空间。只有在需要用到图片的时候才读取它而不是一开始就把图片全部放到内存中。

数据样本将按这样一个字典{'image':image, 'landmarks':landmarks}组织。该数据类将添加一个可选参数transform以方便对样本进行预处理,代码如下:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/4, 0:13
# @Author: Lee
# @File  : dataset_class.py

import os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

import warnings
warnings.filterwarnings("ignore")

plt.ion()  # interactive mode


# 展示一张图片和它对应的标注点
def show_landmarks(image, landmarks):
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)


# 数据集类
# 数据样本按这样一个字典{'image': image, 'landmarks': landmarks}组织。
# 添加一个可选参数transform 以方便对样本进行预处理
class FaceLandmarksDataset(Dataset):
    """人脸标记数据集"""
    def __init__(self, csv_file, root_dir, transform=None):
        """
        csv_file(string):带注释的csv文件的路径。
        root_dir(string):包含所有图像的目录。
        transform(callable, optional):一个样本上的可用的可选变换
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


# 获取图片并可视化部分图片
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')
fig = plt.figure()
for i in range(len(face_dataset)):
    sample = face_dataset[i]
    print(i, sample['image'].shape, sample['landmarks'].shape)
    ax = plt.subplot(1, 4, i+1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
plt.pause(0)

运行结果如下:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Pytorch基础(三)数据集加载及预处理_第2张图片

数据变换

通过上面的例子可知数据集中的图片并不是同样的尺寸。绝大多数神经网络都假定图片的尺寸相同。因此需要做预处理。这里以三个转换*Rescale(缩放图片),*RandomCrop(对图片进行随机剪裁),*ToTensor(把numpy格式的图片转换为torch格式图片,需要交换坐标轴)。

可以将它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。只需要实现__call__方法,必要的时候实现__init__方法。

把这些整合起来以创建一个带组合转换的数据集。每次这个数据集被采样时,*即使地从文件中读取图片*对读取的图片应用转换*,由于其中一步是随机的(Randmpcrop),数据有所增强,现在可以用循环来对所有创建的数据执行同样的操作。

但是,对所有数据简单地使用for循环牺牲了很多功能,尤其是*批量处理数据(指定batch_size)*打乱数据(shuffle置True)*使用多线程(multiprocessingworker)并加载数据。

torch.utils.data.DataLoader是一个提供了上述所有这些功能的迭代器。下面使用的参数必须是清楚的。一个值得关注的参数是collate_fn,可以通过它来决定如何对数据进行批量处理,但是绝大多数情况下默认值就能运行良好。

代码如下:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/4, 0:28
# @Author: Lee
# @File  : data_preprogress.py

import os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

import warnings
warnings.filterwarnings("ignore")

plt.ion()  # interactive mode


# 展示一张图片和它对应的标注点
def show_landmarks(image, landmarks):
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)


class FaceLandmarksDataset(Dataset):
    """人脸标记数据集"""
    def __init__(self, csv_file, root_dir, transform=None):
        """
        csv_file(string):带注释的csv文件的路径。
        root_dir(string):包含所有图像的目录。
        transform(callable, optional):一个样本上的可用的可选变换
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


class Rescale(object):
    """
    将样本中的图像重新缩放到给定大小
    Args:
    output_size(tuple或int):所需的输出大小。如果是元组,则输出为
    与output_size匹配。如果是int,则匹配较小的边缘到output_size保持横纵比相同
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size =output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 9 respectively
        landmarks = landmarks * [new_w / w, new_h / h]
        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """
    随机裁剪样本中的图像
    Args:
        output_size(tuple或int):所需的输出大小,如果是int, 方形裁剪是
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        top = np.random.randint(0, h-new_h)
        left = np.random.randint(0, w-new_w)
        image = image[top: top+new_h, left:left+new_h]
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """将样本中的ndarrays转换为Tensors"""
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        """
        交换颜色轴原因
        numpy包的图片时H*W*C 而torch包的图片是 C*H*W
        """
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}


# 辅助功能:显示批次
def show_landmark_batch(sample_batched):
    """show image with landmarks for a batch of samples"""
    images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i +1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r')
        plt.title('Batch from dataloader')


if __name__ == '__main__':
    # 获取图片并可视化部分图片
    # 数据变换及torchvision.transforms.Compose组合操作
    face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                        root_dir='data/faces/')
    scale = Rescale(256)
    crop = RandomCrop(128)
    composed = transforms.Compose([Rescale(256),
                                   RandomCrop(224)])
    # 在样本上应用上述的变换
    fig = plt.figure()
    sample = face_dataset[65]
    print('数据变换及torchvision.transforms.Compose组合操作')
    for i, tsfrm in enumerate([scale, crop, composed]):
        transformed_sample = tsfrm(sample)
        ax = plt.subplot(1, 3, i + 1)
        plt.tight_layout()
        ax.axis('off')
        ax.set_title(type(tsfrm).__name__)
        show_landmarks(**transformed_sample)
    plt.show()
    plt.pause(0.5)

    # 迭代数据集
    transformed_dataset = FaceLandmarksDataset(csv_file="data/faces/face_landmarks.csv",
                                               root_dir="data/faces/",
                                               transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]))

    for i in range(len(transformed_dataset)):
        sample = transformed_dataset[i]
        print(i, sample['image'].size(), sample['landmarks'].size())

        if i == 3:
            break

    dataloader = DataLoader(transformed_dataset, batch_size=4,
                            shuffle=True, num_workers=4)
    print('迭代数据集,batch_size=4')
    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['landmarks'].size())

        if i_batch == 3:
            plt.figure()
            show_landmark_batch(sample_batched)
            plt.axis('off')
            plt.ioff()
            plt.show()
            break
    plt.pause(0)





运行结果如下:

数据变换及torchvision.transforms.Compose组合操作
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
迭代数据集,batch_size=4
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

Pytorch基础(三)数据集加载及预处理_第3张图片

Pytorch基础(三)数据集加载及预处理_第4张图片

后记

上面的例子用函数实现了数据的部分预处理操作,主要包括使用数据集类(datasets),转换(transform)和数据加载器(DataLoader)。torchvision包提供了畅通的数据集类datasets和转换transforms,可能并不需要我们自己构造这些类。torchvision中海油一个更常用的数据集类ImageFolder。它假定了数据集是以如下方式构造的:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
...
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中'ants','bees'等是分类标签。在PIL.Image中可以使用类似的转换transform,例如RandHorizontalFlip,Scale。利用这些可以按如下方式创建一个数据集加载器(dataloader),以hymenoptera_data(提取码:2rvf)数据集为例:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/4, 1:06
# @Author: Lee
# @File  : data_transform.py

import os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

import warnings
warnings.filterwarnings("ignore")

plt.ion()  # interactive mode

data_transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0,225])
    ])
hymenpotera_dataset = datasets.ImageFolder(root='data/hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenpotera_dataset,
                                              batch_size=4, shuffle=True,
                                              num_workers=4)

print(dataset_loader)

打印结果如下:

Debug可获得如下窗口:

Pytorch基础(三)数据集加载及预处理_第5张图片

 

你可能感兴趣的:(pytorch)