pytorch实现BN、LN、GN

BN实现 

# coding=utf8
import torch
from torch import nn

# track_running_stats=False,求当前 batch 真实平均值和标准差,
# 而不是更新全局平均值和标准差
# affine=False, 只做归一化,不乘以 gamma 加 beta(通过训练才能确定)
# num_features 为 feature map 的 channel 数目
# eps 设为 0,让官方代码和我们自己的代码结果尽量接近
bn = nn.BatchNorm2d(num_features=3, eps=0, affine=False, track_running_stats=False)

# 乘 10000 为了扩大数值,如果出现不一致,差别更明显
x = torch.rand(10, 3, 5, 5)*10000 
official_bn = bn(x)

# 把 channel 维度单独提出来,而把其它需要求均值和标准差的维度融合到一起
x1 = x.permute(1,0,2,3).view(3, -1)
 
mu = x1.mean(dim=1).view(1,3,1,1)
# unbiased=False, 求方差时不做无偏估计(除以 N-1 而不是 N),和原始论文一致
# 个人感觉无偏估计仅仅是数学上好看,实际应用中差别不大
std = x1.std(dim=1, unbiased=False).view(1,3,1,1)

my_bn = (x-mu)/std

diff=(official_bn-my_bn).sum()
print('diff={}'.format(diff)) # 差别是 10-5 级的,证明和官方版本基本一致

GN 计算均值和标准差时,把每一个样本 feature map 的 channel 分成 G 组,每组将有 C/G 个 channel,然后将这些 channel 中的元素求均值和标准差。各组 channel 用其对应的归一化参数独立地归一化。Group Normalization (GN) 适用于占用显存比较大的任务,例如图像分割。对这类任务,可能 batchsize 只能是个位数,再大显存就不够用了。而当 batchsize 是个位数时,BN 的表现很差,因为没办法通过几个样本的数据量,来近似总体的均值和标准差。

import torch
from torch import nn


x = torch.rand(10, 20, 5, 5)*10000

# 分成 4 个 group
# 其余设定和之前相同
gn = nn.GroupNorm(num_groups=4, num_channels=20, eps=0, affine=False)
official_gn = gn(x)

# 把同一 group 的元素融合到一起
x1 = x.view(10, 4, -1)
mu = x1.mean(dim=-1).reshape(10, 4, -1)
std = x1.std(dim=-1).reshape(10, 4, -1)

x1_norm = (x1-mu)/std
my_gn = x1_norm.reshape(10, 20, 5, 5)

diff = (my_gn-official_gn).sum()

print('diff={}'.format(diff)) # 误差在 1e-4 级

BN 的一个缺点是需要较大的 batchsize 才能合理估训练数据的均值和方差,这导致内存很可能不够用,同时它也很难应用在训练数据长度不同的 RNN 模型上。Layer Normalization (LN) 的一个优势是不需要批训练,在单条数据内部就能归一化。把一个 batch 的 feature 类比为一摞书。LN 求均值时,相当于把每一本书的所有字加起来,再除以这本书的字符总数:C×H×W,即求整本书的“平均字”,求标准差时也是同理。

import torch
from torch import nn

x = torch.rand(10, 3, 5, 5)*10000

# normalization_shape 相当于告诉程序这本书有多少页,每页多少行多少列
# eps=0 排除干扰
# elementwise_affine=False 不作映射
# 这里的映射和 BN 以及下文的 IN 有区别,它是 elementwise 的 affine,
# 即 gamma 和 beta 不是 channel 维的向量,而是维度等于 normalized_shape 的矩阵
ln = nn.LayerNorm(normalized_shape=[3, 5, 5], eps=0, elementwise_affine=False)

official_ln = ln(x)

x1 = x.view(10, -1)
mu = x1.mean(dim=1).view(10, 1, 1, 1)
std = x1.std(dim=1,unbiased=False).view(10, 1, 1, 1)

my_ln = (x-mu)/std

diff = (my_ln-official_ln).sum()

print('diff={}'.format(diff)) # 差别和官方版本数量级在 1e-5

 

你可能感兴趣的:(pytorch实现BN、LN、GN)