Internal Covariate Shift(ICS):数据尺度/分布异常,导致训练困难
H 11 = ∑ i = 0 n X i ∗ W 1 i D ( H 11 ) = ∑ i = 0 n D ( X i ) ∗ D ( W 1 i ) = n ∗ ( 1 ∗ 1 ) = n \begin{aligned} \mathrm{H}_{11}=& \sum_{i=0}^{n} X_{i} * W_{1 i} \\ \mathrm{D}\left(\mathrm{H}_{11}\right) &=\sum_{i=0}^{n} D\left(X_{i}\right) * D\left(W_{1 i}\right) \\ &=n *(1 * 1) \\ &=n \end{aligned} H11=D(H11)i=0∑nXi∗W1i=i=0∑nD(Xi)∗D(W1i)=n∗(1∗1)=n
std ( H 11 ) = D ( H 11 ) = n D ( H 1 ) = n ∗ D ( X ) ∗ D ( W ) = 1 \begin{array}{l} \operatorname{std}\left(\mathrm{H}_{11}\right)=\sqrt{\mathbf{D}\left(\mathrm{H}_{11}\right)}=\sqrt{n} \\ \mathbf{D}\left(\mathrm{H}_{1}\right)=\boldsymbol{n} * \boldsymbol{D}(\boldsymbol{X}) * \boldsymbol{D}(\boldsymbol{W})=\mathbf{1} \end{array} std(H11)=D(H11)=nD(H1)=n∗D(X)∗D(W)=1
D ( W ) = 1 n ⇒ std ( W ) = 1 n D(W)=\frac{1}{n} \Rightarrow \operatorname{std}(W)=\sqrt{\frac{1}{n}} D(W)=n1⇒std(W)=n1
1.Batch Normalization(BN)
2.Layer Normalization(LN)
3.Instance Normalization(IN)
4.Group Normalization(GN)
x ^ i ← x i − μ B σ B 2 + ϵ \widehat{x}_{i} \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}} x i←σB2+ϵxi−μB
$$
y_{i} \leftarrow \gamma \widehat{x}{i}+\beta \equiv \mathrm{N}{\gamma, \beta}\left(x_{i}\right)
均值和方差求取方式
起因:BN不适合用于变长的网络,如RNN
思路:逐层计算均值和方差
1.不再有running_mean 和 running_var
2.gamma 和 beta 为逐元素、逐特征的
主要参数:
normalized_shape:该层特征形状
eps:分母修正项
elementwise_affine:是否需要affine transform
起因:BN在图像生成(Image Ganeration)中不适用
思路:==逐Instance(channel)==计算均值和方差
计算方式 逐通道的
主要参数:
num_features:一个样本特征数量(最重要)
eps:分母修正项
momentum:指数加权平均估计当前mean/var
affine:是否需要affine transform
track_running_stats:是训练状态,还是测试状态
起因:小batch样本中,BN估计的值不准
思路:数据不够,通道来凑
1.不再有running_mean和running_var
2.gamma 和beta 为逐通道(channel)的
主要参数
num_groups 分组数 通产设为2的n次方
num_channels 通道数(特征数)
eps 分母修正项
affine 是否需要affine transform
BN LN IN GN 都是为了克服Internal Covariate shift(ICS)
减均值 除标准差 乘γ 加β
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# ======================================== nn.layer norm
# flag = 1
flag = 0
if flag:
batch_size = 2
num_features = 3
features_shape = (2,2)
# features_shape = (3, 4)
feature_map = torch.ones(features_shape) # 2D
feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0) # 3D
feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0) # 4D
# feature_maps_bs shape is [8, 6, 3, 4], B * C * H * W
ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=True)
# ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=False)
# ln = nn.LayerNorm([6, 3, 4])
# ln = nn.LayerNorm([6, 3])
output = ln(feature_maps_bs)
print("Layer Normalization")
print(ln.weight.shape)
print(feature_maps_bs[0, ...])
print(output[0, ...])
# ======================================== nn.instance norm 2d
# flag = 1
flag = 0
if flag:
batch_size = 3
num_features = 3
momentum = 0.3
features_shape = (2, 2)
feature_map = torch.ones(features_shape) # 2D
feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0) # 3D
feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0) # 4D
print("Instance Normalization")
print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))
instance_n = nn.InstanceNorm2d(num_features=num_features, momentum=momentum)
for i in range(1):
outputs = instance_n(feature_maps_bs)
print(outputs)
# print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
# print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
# print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
# print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))
# ======================================== nn.grop norm
flag = 1
# flag = 0
if flag:
batch_size = 2
num_features = 4
# 设置分组数时一定是能被整除的 通常设置为2的N次幂
num_groups = 2 # 3 Expected number of channels in input to be divisible by num_groups
features_shape = (2, 2)
feature_map = torch.ones(features_shape) # 2D
feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0) # 3D
feature_maps_bs = torch.stack([feature_maps * (i + 1) for i in range(batch_size)], dim=0) # 4D
# 分组数 有几个特征图
gn = nn.GroupNorm(num_groups, num_features)
outputs = gn(feature_maps_bs)
print("Group Normalization")
print(gn.weight.shape)
print(outputs[0])