PyTorch|transforms.Normalize

It is generally accepted that normalizing image data before training helps gradient descent find a good optimum. So what exactly does PyTorch's transforms.Normalize do? That is worth understanding.

Consider the following formula, where x is drawn from a set of data C, mean is the mean of that set, and std is its standard deviation:

x = (x - mean) / std

Simply put, Normalize updates every input value according to this formula.

Consider this code:

import numpy as np
List1 = np.array([1, 2, 3, 4])
mean = np.mean(List1)
std = np.std(List1)
List2 = (List1 - mean) / std
>>> List1
array([1, 2, 3, 4])
>>> List2
array([-1.34164079, -0.4472136 ,  0.4472136 ,  1.34164079])

After this normalization, List1 becomes List2.

So how does Normalize work on image data?

Suppose we have four images. Using the data-loading approach from the earlier article, we import them:

import os
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

path = "E:\\3-10\\dogandcats\\source"
IMG = []
filenames = [name for name in os.listdir(path)]
for i, filename in enumerate(filenames):
    img = Image.open(os.path.join(path, filename))
    img = img.resize((28, 28))                     # resize each image to 28x28
    img = np.array(img)                            # convert the image to a numpy array
    img = torch.tensor(img, dtype=torch.float32)   # to a float tensor (torch.mean/std require floats)
    img = img.permute(2, 0, 1)                     # H, W, C -> C, H, W
    IMG.append(img)                                # collect the image tensors in the list IMG
IMGEND = torch.stack(IMG, dim=0)                   # stack into one (N, C, H, W) tensor
>>> IMGEND.size()
torch.Size([4, 3, 28, 28])

The four images have been imported and converted into a single tensor.

Get the mean of the r, g, b channels:

>>> mean = torch.mean(IMGEND, dim=(0, 2, 3), keepdim=True)
>>> mean
tensor([[[[160.8753]],
         [[149.3600]],
         [[126.5810]]]])

Get the standard deviation of the r, g, b channels:

>>> std = torch.std(IMGEND, dim=(0, 2, 3), keepdim=True)
>>> std
tensor([[[[61.7317]],
         [[65.0915]],
         [[84.2025]]]])

Normalize:

>>> process = transforms.Normalize([160.8753, 149.3600, 126.5810],
...                                [61.7317, 65.0915, 84.2025])
>>> dataend1 = process(IMGEND)
>>> dataend1
tensor([[[[-1.3587, -0.9213, -0.7269,  ..., -0.3382, -0.3868, -0.4516],
          [-1.4397, -0.8727, -0.6135,  ..., -0.1114, -0.1762, -0.2248],
          [-1.8771, -1.3587, -0.9375,  ...,  0.1640,  0.0830, -0.1438],
          ...,
          [-1.9095, -1.8285, -1.8123,  ..., -2.1687, -2.2497, -2.2173],
          [-1.9419, -1.8609, -1.8123,  ..., -2.3469, -2.4117, -2.2983],
          [-1.9257, -1.8447, -1.8447,  ..., -2.3307, -2.3307, -2.2821]],

         [[-1.0502, -0.4357, -0.0055,  ...,  0.4246,  0.3785,  0.3325],
          [-1.1424, -0.4203,  0.0406,  ...,  0.5783,  0.5475,  0.5168],
          [-1.6340, -1.0656, -0.4664,  ...,  0.7626,  0.7319,  0.5629],
          ...,
          [-1.6340, -1.5572, -1.5418,  ..., -1.6955, -1.7723, -1.7876],
          [-1.6801, -1.5879, -1.5418,  ..., -1.9413, -2.0027, -1.8491],
          [-1.6186, -1.5726, -1.5726,  ..., -1.9259, -1.8952, -1.8645]],

         [[-0.4938,  0.0881,  0.5988,  ...,  1.0026,  0.9788,  0.9313],
          [-0.5888,  0.0762,  0.6107,  ...,  1.0738,  1.0501,  1.0382],
          [-1.0758, -0.5532,  0.1000,  ...,  1.1807,  1.1570,  1.0501],
          ...,
          [-0.9926, -0.9332, -0.9451,  ..., -1.3845, -1.4083, -1.3845],
          [-1.0401, -0.9689, -0.9570,  ..., -1.4320, -1.4439, -1.3370],
          [-0.9926, -0.9451, -0.9570,  ..., -1.4320, -1.4558, -1.3964]]],

        ...])  # output for the remaining three images trimmed

Now compute the same transformation directly from the formula:

>>> enddata = (IMGEND - mean) / std
>>> enddata
tensor([[[[-1.3587, -0.9213, -0.7269,  ..., -0.3382, -0.3868, -0.4516],
          [-1.4397, -0.8727, -0.6135,  ..., -0.1114, -0.1762, -0.2248],
          [-1.8771, -1.3587, -0.9375,  ...,  0.1640,  0.0830, -0.1438],
          ...,
          [-1.9095, -1.8285, -1.8123,  ..., -2.1687, -2.2497, -2.2173],
          [-1.9419, -1.8609, -1.8123,  ..., -2.3469, -2.4117, -2.2983],
          [-1.9257, -1.8447, -1.8447,  ..., -2.3307, -2.3307, -2.2821]],

         [[-1.0502, -0.4357, -0.0055,  ...,  0.4246,  0.3785,  0.3325],
          [-1.1424, -0.4203,  0.0406,  ...,  0.5783,  0.5475,  0.5168],
          [-1.6340, -1.0656, -0.4664,  ...,  0.7626,  0.7319,  0.5629],
          ...,
          [-1.6340, -1.5572, -1.5418,  ..., -1.6955, -1.7723, -1.7876],
          [-1.6801, -1.5879, -1.5418,  ..., -1.9413, -2.0027, -1.8491],
          [-1.6186, -1.5726, -1.5726,  ..., -1.9259, -1.8952, -1.8645]],

         [[-0.4938,  0.0881,  0.5988,  ...,  1.0026,  0.9788,  0.9313],
          [-0.5888,  0.0762,  0.6107,  ...,  1.0738,  1.0501,  1.0382],
          [-1.0758, -0.5532,  0.1000,  ...,  1.1807,  1.1570,  1.0501],
          ...,
          [-0.9926, -0.9332, -0.9451,  ..., -1.3845, -1.4083, -1.3845],
          [-1.0401, -0.9689, -0.9570,  ..., -1.4320, -1.4439, -1.3370],
          [-0.9926, -0.9451, -0.9570,  ..., -1.4320, -1.4558, -1.3964]]],

        ...])  # output for the remaining three images trimmed

Clearly the two results are identical, which shows that transforms.Normalize does nothing more than apply this formula to the input. (Both paths here use torch.std, which computes the sample standard deviation by default; that differs from np.std, but the comparison is consistent because the same std values are used on both sides.)

Moreover, when the mean and std passed to transforms.Normalize are those of the data being transformed, the resulting data is guaranteed to have mean 0 and standard deviation 1.

When they are not the actual mean and std of the data, the new data's mean is not necessarily 0 and its standard deviation is not necessarily 1; the formula is simply applied as given.

Like this:

>>> process = transforms.Normalize([0.5, 0.6, 0.4], [0.36, 0.45, 0.45])
>>> data = process(inputdata)

Here [0.5, 0.6, 0.4] and [0.36, 0.45, 0.45] are not the mean and std of inputdata; they are arbitrary values used only to transform the data, so the result's mean need not be 0 nor its standard deviation 1.

Of course, when preprocessing images you will often see these two transforms appear together:

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

Here transforms.ToTensor() converts the input to a tensor, changes the shape from H, W, C to C, H, W, and divides every value by 255, scaling the data into [0, 1].

Plugging the extremes of [0, 1] into the formula x = (x - mean) / std gives:

(0 - 0.5) / 0.5 = -1

(1 - 0.5) / 0.5 = 1

So the new data falls in the range [-1, 1], but its mean is not necessarily 0 and its standard deviation is not necessarily 1; this is important to understand.

That is because [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5] are not necessarily the actual mean and std of the original data.
