PyTorch 1.0 数据加载与预处理

PyTorch 1.0 数据加载与预处理

    • 数据准备
    • 数据集的`类`
    • 组合 transforms
    • 迭代数据集
    • 最后,Torchvison
    • 更多

数据准备

在机器学习中,准备数据往往要耗费许多力气. PyTorch提供了许多工具来使数据加载变得更简单有用,同时可以让代码的可读性更高. 本文将演示怎样加载和预处理(包括数据增广)一个牛逼的数据集.
本文将用到下面两个工具包:

  • scikit-image:用于图像的输入输出和变换
  • pandas: 让我们更容易解析csv文件
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warning
import warnings
warnings.filterwarnings("ignore")

plt.ion() # interactive mode

我们要用到的是关于人脸姿态的数据集,每个人脸图像中包含了68个不同的标记点.
PyTorch 1.0 数据加载与预处理_第1张图片

  • 提示:
    从这里下载数据集到此目录下/data/faces/. 这个数据集实际上来自这个项目,它选自imagenetface类别的一小部分.

数据集中附带一个包含有标注信息的csv文件:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

接下来我们可以从csv中读出(N, 2)的标注信息数组,N指的是标注的数量.

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

Out:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

为了将标注点在图片上进行可视化,我们需要写一个简单的函数来可视化一个数据.

def show_landmarks(image, landmarks):
	"""
	Show image with landmarks
	"""
	plt.imshow(image)
	plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
	plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)), landmarks)
plt.show()

PyTorch 1.0 数据加载与预处理_第2张图片

数据集的

torch.utils.data.Dataset是一个表示数据集的抽象类. 我们的自定义数据集需要继承Dataset并重写下列方法:

  • __len__: 一般定义成len(dataset),用于返回数据集的size
  • __getitem__: 提供索引化的支持,使dataset[i]可以获得第i个样本

接下来创建一个人脸标注数据集的类,在__init__方法中读取csv,在__getitem__中读取图片,这样做的目的是因为这样每个图像仅需要被使用时才会被加载到存储中,提高内存的利用率.
数据集样本的格式是一个字典类型:{‘image’: image,'landmarks': landmarks}. 数据类中可以添加一个额外的参数transform,这样可以使样本可以在这里完成一些预处理的操作. 后面我们就能更具体地看到这个transform的用处.

class FaceLandmarksDataset(Dataset):
	"""Face Landmarks dataset """
	def __init__(self, csv_file, root_dir, transform=None):
		"""
		Args:
			csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
		"""
		self.landmarks_frame = pd.read_csv(csv_file)
		self.root_dir = root_dir
		self.transform = transform
	def __len__(self):
		return len(self.landmarks_frame)
	def __getitem__(self, idx):
		img_name = os.path.join(self.root_dir,
								self.landmarks_frame.iloc[idx, 0])
		image = io.imread(img_name)
		landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
		landmarks = landmarks.astype('float').reshape(-1, 2)
		sample = {'image': image, 'landmarks': landmarks}
		if self.transform:
			sample = self.transform(sample)
		
		return sample

接下来实例化上述的类,并将数据样本迭代一遍,输出前4个样本以及他们的标注点看一下.

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
									root_dir='data/faces')
fig = plt.figure()

for i in range(face_dataset):
	sample = face_dataset[i]
	print(i, sample['image'].shape, sample['landmarks'].shape)
	
	ax = plt.subplot(1, 4, i+1)
	plt.tight_layout()
	ax.set_title('Sample #{}'.format(i))
	ax.axis('off')
	show_landmarks(**sample)

	if i==3:
		plt.show()
		break

PyTorch 1.0 数据加载与预处理_第3张图片

Out:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Transforms
从上面显示的图片中可以发现一个问题,就是样本的尺寸不一样. 而一般来说,神经网络要求输入图像有固定的尺寸. 因此,我们可以使用Transforms写一些预处理代码:

  • Rescale: 调整图像的尺寸.
  • RandomCrop: 随机裁剪图像,这样可以做数据增广.
  • ToTensor: 将numpy类型的图像转换为torch类型(接着还需要调整坐标轴位置).
    我们将上述功能写成一个可调用的类,而不是简单的函数,这样就不必每次调用该函数时都得传递参数了. 我们仅需要实现__call__方法,或者,如果有必要的化,再实现一下__init__即可. 最后,只需要下面2行代码,即可使用刚已定义的类:
tsfm = Transform(params)
transformed_sample = tsfm(sample)

考验大家观察力的时候又到了,看看下面的代码中,图像和标注点是怎样被使用transforms的.

class Rescale(object):
	"""Rescale the image in a sample to a given size.
	Args:
		output_size(tuple or int): Desired output size. if tuple, output is matched to output_size.  if int, smaller of image edges is matched to output_size keeping aspect ratio the same.
	"""
	def __init__(self, output_size):
		assert isinstance(output_size, (int, tuple))
		self.output_size = output_size

	def __call__(self, sample):
		image, landmarks = sample['image'], sample['landmarks']

		h, w = image.shape[:2]
		if isinstance(self.output_size, int):
			if h > w:
				new_h, new_w = self.output_size * h / w, self.output_size
			else:
				new_h, new_w = self.output_size, self.output_size * w / h
		else:
			new_h, new_w = self.output_size
		
		new_h, new_w = int(new_h), int(new_w)
		img = transform.resize(image, (new_h, new_w))
		# h and w are swapped for landwarks because for images,
		# x and y axes are axis 1 and 0 respectively
		landmarks = landmarks * [new_w / w, new_h /h]
		
		return {'image': img, 'landmarks': landmarks}

class RandomCrop(object):
	"""Crop randomly the image in a sample
	Args:
		output_size (tuple or int): Desired output size. If int, square crop is make.
	"""
	def __init__(self, output_size):
		assert isinstance(output_size, (int, tuple))
		if isinstance(output_size, int):
			self.output_size = (output_size, output_size)
		else:
			assert len(output_size) == 2
			self.output_size = output_size
	def __call__(self, sample):
		image, landmarks = sample['image'], sample['landmarks']

		h, w = image.shape[:2]
		new_h, new_w = self.output_size
		
		top = np.random.randint(0, h - new_h)
		left = np.random.randint(0, w - new_h)
	
		image = image[top: top + new_h,
						left: left + new_w]
		
		landmarks = landmarks - [left, top]
		
		return {'image': image, '': landmarks}
class ToTensor(object):
	"""Convert ndarrays in sample to Tensors."""
	def __call__(self, sample):
		image, landmarks = sample['image'], sample['landmarks']

		# swap color axis because 
		# numpy image: H x W x C
		# torch image:C x H x W
		image = image.transpose((2, 0, 1))
		return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}

组合 transforms

all right,我们现在将上述转换应用到一个样本当中.
我们的目标是将图像的短边尺寸调整到256,然后随机地剪裁出一个224的正方形,我们需要用到RescaleRandomCrop两个变换类. torchvision.transforms.Compose是一个简单的可调用类,允许我们将上述功能合在一起.

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
								RandomCrop(224)])
# Apply each of the above transforms on  sample
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
	transformed_sample = tsfm(sample)
	
	ax = plt.subplot(1, 3, i+1)
	plt.tight_layout()
	ax.set_title(type(tsfm).__name__)
	transformed_sample = tsfm(sample)
	
plt.show()

PyTorch 1.0 数据加载与预处理_第4张图片

迭代数据集

我们接下来将上文所提到的所有方法合在一起. 总之就是,数据集每次都进行了如下操作:

  • 图像在程序一边运行一边读取
  • 读取到的图像做转换处理
  • 因为在转换时有随机裁剪,所以这样就有样本扩增的效果
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

Out:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,上面采用的for循环迭代的方法,我们丧失了很多必要的特性,特别是以下几点:

  • 数据批量化
  • 数据打乱
  • 使用multiprocessing并行加载数据

``是一个具有上述全部特性的迭代器,可以指定的一个有趣的参数是collate_fn.

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

PyTorch 1.0 数据加载与预处理_第5张图片
Out:

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

最后,Torchvison

本文展示了如何编写和使用数据集,转换,数据加载器. torchvision工具包提供了一些通用的数据集和转换方法. 你也许不需要编写自定义的类. 另一个常用的数据集类是ImageFolder,使用这个类需要将图片文件名放到一个文本中,如:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

这个转换方法可以使用PIL.Image库中的RandomHorizontalFlipScale方法. 你可以使用这些写一个数据加载器,如下:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

更多

请点着这里查阅官方网站.

你可能感兴趣的:(pytorch)