Pytorch基础及实战(4)——transforms数据增强

在AI领域的模型训练中通常会遇到模型过拟合问题,通常采取的办法就是数据增强处理,例如在图像处理中,数据增强是指对原始图像进行旋转、缩放、剪切、翻转等操作,以扩大训练数据集的规模,提高模型泛化能力,降低过拟合风险。

笔者在这里以深度学习框架Pytorch中的数据增强工具(transforms模块)为例介绍数据增强处理。torchvision.transforms是PyTorch中用于图像处理和数据增强的模块。它提供了许多函数,可以在图像上应用各种转换,例如裁剪、旋转、翻转、缩放、归一化等操作,从而生成更多变化的图像数据。

transforms模块中的函数可以分为两类:一类是针对PIL图像对象的操作函数,例如:Resize、RandomCrop、RandomHorizontalFlip等;另一类是对Tensor对象的操作函数,例如:Normalize、ToTensor等。下面笔者将分别介绍这些函数的原理和代码示例。

针对PIL图像对象的操作函数

Resize

Resize函数可以将图像缩放到指定大小。常用的缩放方法等比例缩放和非等比例缩放两种。其中如果预先目标大小,Resize函数会按照目标大小等比例缩放图像,如果预先仅指定了图片宽度或高度,则会进行非等比例缩放。

from torchvision import transforms
from PIL import Image

# 等比例缩放
transform = transforms.Compose([transforms.Resize((224, 224))])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)

# # 非等比例缩放
transform = transforms.Compose([transforms.Resize((224, 300))])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)
  • 等比例缩放
    Pytorch基础及实战(4)——transforms数据增强_第1张图片

  • 非等比例缩放

Pytorch基础及实战(4)——transforms数据增强_第2张图片

CenterCrop

CenterCrop函数可以从图像中心裁剪指定大小的区域。

from torchvision import transforms
from PIL import Image

transform = transforms.Compose([transforms.CenterCrop(224)])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)

Pytorch基础及实战(4)——transforms数据增强_第3张图片

RandomCrop

RandomCrop函数可以随机裁剪指定大小的区域。

from torchvision import transforms
from PIL import Image

transform = transforms.Compose([transforms.RandomCrop(224)])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)

Pytorch基础及实战(4)——transforms数据增强_第4张图片

RandomHorizontalFlip

RandomHorizontalFlip函数可以随机水平翻转图像。

from torchvision import transforms
from PIL import Image

transform = transforms.Compose([transforms.RandomHorizontalFlip()])
img = Image.open('1.png')
img.show(img)
img = transform(img)
img.show(img)

Pytorch基础及实战(4)——transforms数据增强_第5张图片

针对Tensor对象的操作函数

ToTensor

ToTensor函数可以将PIL图像对象转换为Tensor对象。

from torchvision import transforms
from PIL import Image

transform = transforms.Compose([transforms.ToTensor()])
img = Image.open('1.png')

img = transform(img)
print(img.size()) #torch.Size([3, 800, 1000])

Normalize

Normalize函数可以对Tensor对象进行归一化,以减少模型训练的时间,提高模型性能和稳定性。

from torchvision import transforms
from PIL import Image
import torch

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = Image.open('1.png')
img = transform(img)
print(torch.min(img), torch.max(img)) # tensor(-1.) tensor(1.)

上述代码将图像像素值归一化到[-1, 1]之间。

Compose函数

Compose函数则用于将多个transforms函数组合在一起,形成一个transforms的列表。在数据加载时,会按照列表中的顺序,依次对图像进行变换。

from torchvision import transforms

transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = Image.open('1.png')
img = transform(img)
print(torch.min(img), torch.max(img))

上述代码中,我们通过Compose函数将CenterCrop、RandomHorizontalFlip、ToTensor和Normalize四个函数组合在一起,形成了一个transform对象。

你可能感兴趣的:(Pytorch原理及实战,深度学习,pytorch)