在深度学习的问题中处理数据都会占据比较大的时间,只有把数据处理好了才有可能对模型进行训练、测试等后续工作。
PyTorch提供了很多用于让数据加载变得更加方便的工具,接下来我们就来学习一下怎么样处理那些PyTorch没有提供直接接口的数据。
在学习这个之前,首先要保证电脑上已经安装了下面这两样东西:
这篇文章内容还是比较多的,但认真看完应该就可以掌握各种数据集的处理了。
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
接下来我们要处理的数据集是关于脸部姿势的,每张图片都会被注释成这样,每张脸上都会有68各不同的标记点:
戳这里下载需要教程中用到的脸部数据集,跟数据集一起的还有一个注释文件face_landmarks.csv
。
直接打开如下图所示:
即每张图片都对应一个文件名和对应的N个脸部特征标记点。
在注释文件中的是N个坐标点,每个坐标点由两个横纵坐标组成。所以先用pandas工具把注释文件处理一下。
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
def show_landmarks(image, landmarks):
"""Show image with landmarks"""
plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('faces/', img_name)),
landmarks)
plt.show()
torch.utils.data.Dataset
是一个PyTorch用来表示数据集的抽象类。我们用这个类来处理自己的数据集的时候必须继承Dataset
,然后重写下面的函数:
__init__
函数中完成csv文件的读取工作;__getitem__
函数中完成图片的读取工作。这样是为了减小内存开销,只要在需要用到的时候才将图片读入。transform
,用来对图像做一些改变,具体的会在下面进行介绍。{‘image':image,'landmarks':landmarks}
那么现在我们就可以写出类的定义:
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
接下来我们对上面定义好的类做实例化,然后在数据样本上进行迭代。我们会打印前4个样本图像及其对应的坐标点。
face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
root_dir='faces/')
fig = plt.figure()
for i in range(len(face_dataset)):
sample = face_dataset[i]
print(i, sample['image'].shape, sample['landmarks'].shape)
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break
从上面显示的图片我们可以看到每张图片的大小都不一样,但往往我们在处理神经网络的输入图像的时候都希望它们有一个相对固定的大小。因此,我们需要一些对图像进行预处理的工作。
我们试着写一下这三个常用的变换功能:
Rescale
:重新调整图像大小;RandomCrop
:随机从图像中截取一部分;ToTensor
:将numpy类型表示的图像转换成torch表示的图像。我们用类而不是函数来实现以上这三个功能,主要是考虑到如果用函数的话,每次都需要传入参数,但是用类就可以省掉很多麻烦。我们只需要实现每个类的__call__
函数和__init__
函数。
下面是对这三个功能的实现:
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
"""Crop randomly the image in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top: top + new_h,
left: left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
# swap color axis because
# numpy image: H x W x C
# torch image: C X H X W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),
'landmarks': torch.from_numpy(landmarks)}
假设我们现在需要将图像的较短边调整到256,然后从中随机截取224的正方形图像。我们就可以调用torchvision.transforms.Compose
将以上的Rescale
和RandomCrop
两个变换组合起来。
以下的代码段展示了分开进行变换以及用Compose组合进行变换的结果图
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()
简单回顾一下:
dataset
类;如果你还记得的话,我们在之前定义dataset
的时候是有一个transform
参数的,但我们在第4节中是先取了样本数据,然后再进行变换操作,并没有将其作为参数传到dataset
中。所以我们现在要做的工作就是将所有的内容集成到一起。每次抽取一个样本,都会有以下步骤:
其实我们只要把Transform
的部分作为形参传入dataset
就可以了,其他的都不变。
然后用for循环来依次获得数据集样本。
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
root_dir='faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))
for i in range(len(transformed_dataset)):
sample = transformed_dataset[i]
print(i, sample['image'].size(), sample['landmarks'].size())
if i == 3:
break
以上我们已经实现了dataset
与transform
的合并,也实现了用for循环来获取每一个样本数据,好像事情就已经结束了。
但等等,真的结束了吗?emmmm,我们好像还落了什么事情,是的没错:
torch.utils.data.DataLoader
这个类为我们解决了以上所有的问题,是不是很腻害~
只要按照要求设置DataLoader
的参数即可:
如下所示即实现了DataLoader
函数的使用,及批样本数据的展示。
dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=4)
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
images_batch, landmarks_batch = \
sample_batched['image'], sample_batched['landmarks']
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
for i in range(batch_size):
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
landmarks_batch[i, :, 1].numpy(),
s=10, marker='.', c='r')
plt.title('Batch from dataloader')
for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size())
# observe 4th batch and stop.
if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break
torchvision
包提供了一些常用的数据集和转换函数。使用torchvision
甚至不需要自己写处理函数。
在torchvision
中最通用的数据集是ImageFolder
,它假设数据结构为如下:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
这里的root
指代根目录,ants bees
指的是不同的类标签,后面的是具体的图片名称。
当然它还提供了对PIL.Image
的常用操作,包括RandomHorizontalFlip Scale
等等。
以下为用torchvision
实现的超简化版本的数据处理方法:
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
我们来整理一下整个实现思路哦~
主要分以下三种情况:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
Data Loading and Processing Tutorial
文章来源:
作者:与阳光共进早餐
链接:https://www.jianshu.com/p/6e22d21c84be
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。