PyTorch数据集预处理函数transforms.ToTensor和Normalize

在学习网上现成的图片分类实例时,在源代码的开头总能看到以下代码

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

强迫症的我看不懂就要弄清楚它代表的含义。pytorch入门时会用到torchvision,torchvision是独立于pytorch的关于图像操作的工具库,包含了目前流行的数据集,模型结构和常用的图片转换工具。详细介绍点击这。这里的数据集是CIFAR10。

torchvision.transforms

transforms.ToTensor

torchvision 数据集的输出是范围 [0, 1] 的 PILImage 图像。我们将它们转换为标准化范围 [-1, 1] 的张量。转换就要用到上面的代码。

transforms.ToTensor()

ToTensor就是把一个取值范围是[0,255]的PIL.Image图像或者shape为(H,W,C)(Height,Width,Channel)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor

transforms.Normalize(mean, std)

Normalize就是对图像进行标准化,给定均值mean:(R,G,B) 方差std:(R,G,B),就会把Tensor正则化。

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

(0.5, 0.5, 0.5)就是RGB三通道上分别给定的均值和方差。标准化经过下面公式进行变换。

Normalized_image=(image-mean)/std

官方文档也有详细说明
比如前面的ToTensor转换为[0,1.0]的Tensor,然后最小值0就转换成(0-0.5)/0.5=-1,最大值1就转换成(1-0.5)/0.5=1,三通道R,G,B都如此经过Normalize()变换后,每个样本图像就符合了均值为0方差为1的标准正态分布。

为什么要这样处理?

数据经过归一化和标准化后可以加快梯度下降的求解速度,让数据远离Sigmoid激活函数的饱和区,这就是Batch Normalization等技术非常流行的原因,它使得可以使用更大的学习率更稳定地进行梯度传播,甚至增加网络的泛化能力。总而言之,提高训练速度,好处多多。

你可能感兴趣的:(深度学习pytorch版,pytorch,深度学习,人工智能)