nn.LayerNorm的具体实现方法(通过公式复现)

以下通过LayerNorm的公式复现了Layer Norm的计算结果,以此来具体了解Layer Norm的工作方式
公 式 : y = x − E [ x ] V a r [ x ] + ϵ ∗ γ + β 公式:y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y=Var[x]+ϵ xE[x]γ+β

1.只考虑最低维:每个维各自按公式计算即可,不和其他维度掺和

# LayerNorm
a = torch.tensor([[1,2,4,1],[6,3,2,4],[2,4,6,1]]).float()
nn.LayerNorm([4])(a)

结果:
tensor([[-0.8165, 0.0000, 1.6330, -0.8165],
[ 1.5213, -0.5071, -1.1832, 0.1690],
[-0.6509, 0.3906, 1.4321, -1.1717]])

# 使用公式实现,注意方差是有偏样本方差,系数是1/n
a = torch.tensor([[1,2,4,1],[6,3,2,4],[2,4,6,1]]).float()
b0 = (a[0] - torch.mean(a[0]))/((torch.var(a[0], unbiased=False)+1e-5)**0.5)#有偏样本方差
b1 = (a[1] - torch.mean(a[1]))/((torch.var(a[1], unbiased=False)+1e-5)**0.5)
b2 = (a[2] - torch.mean(a[2]))/((torch.var(a[2], unbiased=False)+1e-5)**0.5)
[b0, b1, b2]

结果:
[tensor([-0.8165, 0.0000, 1.6330, -0.8165]),
tensor([ 1.5213, -0.5071, -1.1832, 0.1690]),
tensor([-0.6509, 0.3906, 1.4321, -1.1717])]

2.考虑最低的2个维度:计算最低两维的12个元素的方差均值,然后按照公式算

a = torch.tensor([[1,2,4,1],[6,3,2,4],[2,4,6,1]]).float()
nn.LayerNorm([3,4])(a)

结果:
tensor([[-1.1547e+00, -5.7735e-01, 5.7735e-01, -1.1547e+00],
[ 1.7320e+00, 1.1921e-07, -5.7735e-01, 5.7735e-01],
[-5.7735e-01, 5.7735e-01, 1.7320e+00, -1.1547e+00]],
grad_fn=)

a = torch.tensor([[1,2,4,1],[6,3,2,4],[2,4,6,1]]).float()
b = (a - torch.mean(a))/((torch.var(a, unbiased=False)+1e-5)**0.5)

结果:
tensor([[-1.1547, -0.5773, 0.5773, -1.1547],
[ 1.7320, 0.0000, -0.5773, 0.5773],
[-0.5773, 0.5773, 1.7320, -1.1547]])

3. 多维的例子复现

import torch
from torch import nn
# 机算
x = torch.rand(2,3,4,5)
layer = nn.LayerNorm(5)
out1 = layer(x)
# 手算
mean = x.mean(axis=3).reshape(-1,x.shape[1],x.shape[2],1)
var = x.var(axis=3,unbiased=False).reshape(-1,x.shape[1],x.shape[2],1)
out2 = (x-mean)/((var+1e-5)**0.5)

# 机算
x = torch.rand(2,3,4,5)
layer = nn.LayerNorm([4,5])
out1 = layer(x)
# 手算:
mean = x.reshape(x.shape[0],x.shape[1],-1).mean(axis=2).reshape(-1,x.shape[1],1,1)
var = x.reshape(x.shape[0],x.shape[1],-1).var(axis=2,unbiased=False).reshape(-1,x.shape[1],1,1)
out2 = (x-mean)/((var+1e-5)**0.5)

4.其他参数细节见此博客:

https://blog.csdn.net/weixin_39228381/article/details/107939602

你可能感兴趣的:(python,pytorch,开发语言)