PyTorch学习笔记(4)torchvision

torchvision有4个功能模块:model、datasets、transforms和utils。利用datasets可以下载一些经典数据集,本次笔记主要记录如何使用datasets的ImageFolder处理自定义数据集,以及如何使用transforms对源数据进行预处理、增强等。

1. transforms

transforms提供了对PIL Image对象和Tensor对象的常用操作。

1)对PIL Image的常见操作如下。
Scale/Resize:调整尺寸,长宽比保持不变。
CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizedCrop则是random size的crop。
Pad:填充。
ToTensor:把一个取值范围是[0,255]的PIL.Image转换成Tensor。形状为(H,W,C)的Numpy.ndarray转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloatTensor。
RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5。
RandomVerticalFlip:图像随机垂直翻转。
ColorJitter:修改亮度、对比度和饱和度。

2)对Tensor的常见操作如下。
Normalize:标准化,即,减均值,除以标准差。
ToPILImage:将Tensor转为PIL Image。

如果要对数据集进行多个操作,可通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential。以下为示例代码:
这个东西会被送入你自定义的Dataset中!

transforms.Compose([
	# 将给定的PIL.Image进行中心切割,得到给定的size
	# size可以是tuple,(target_height, target_width)
	# size也可以是一个Integer, 切出来一个正方形
	transform.CenterCrop(10)
	# 切割中心点的位置随机选取
	transforms.RandomCrop(20, padding=0)
	# 将一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray
	# 转换为形状为(C,H,W),取值范围是[0,1]的torch.FloatTensor
	transforms.ToTensor()
	# 规范化到[-1, -1]
	transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5,0.5,0.5))
])

2. datasets.ImageFolder

当文件依据标签处于不同文件下时,如:
PyTorch学习笔记(4)torchvision_第1张图片
我们可以利用torchvision.datasets.ImageFolder来直接构造出dataset

loader = datasets.ImageFolder(path)
loader = data.DataLoader(dataset)

ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。

下面我们利用ImageFolder读取不同目录下的图片数据,然后使用transforms进行图像预处理,预处理有多个,我们用compose把这些操作拼接在一起。然后使用DataLoader加载。对处理后的数据用torchvision.utils中的save_image保存为一个png格式文件,然后用Image.open打开该png文件,详细代码如下:

from torchvision import transforms, utils
from torchvision import datasets
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt


my_trans = transforms.Compose([
    transforms.RandomResizedCrop(224), #将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小
    transforms.RandomHorizontalFlip(), #图像水平翻转
    transforms.ToTensor()
])

train_data = datasets.ImageFolder(r'./data/torchvision_data', transform = my_trans)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        utils.save_image(grid,'test02.png')
    break

这里我建立一个torchvision_data文件夹,把不同类型的图片放在不同的子文件夹下。
PyTorch学习笔记(4)torchvision_第2张图片
运行结果为:
在这里插入图片描述
可以看到图像尺寸缩小,并被水平翻转,最后拼接在一起。

参考文档:https://pytorch.org/docs/stable/torchvision/transforms.html

你可能感兴趣的:(Pytorch,python,pytorch)