RMSNorm均方根标准化

一、目录

1 定义
2 实现

二、实现

  1. 定义
    layer normalization 重要的两个部分是平移不变性和缩放不变性。 Root Mean Square Layer Normalization 认为 layer normalization 取得成功重要的是缩放不变性,而不是平移不变性。因此,去除了计算过程中的平移,只保留了缩放,进行了简化,提出了RMS Norm(Root Mean Square Layer Normalization),即均方根 norm。
    RMSNorm均方根标准化_第1张图片
    优点:训练速度更快,效果相当。
    2 实现
#均方根标准化
class RMSNorm(torch.nn.Module):
    def __init__(self,normalized_shape,eps=1e-5,devices=None,dtype=None,**kwargs):
        super().__init__()
        self.weight=torch.nn.Parameter(torch.empty(size=normalized_shape,device=devices,dtype=dtype))   #待训练的参数
        self.eps=eps
    def forward(self,hidden_state:torch.Tensor):
        input_type=hidden_state.dtype
        variace=hidden_state.to(torch.float32).pow(2).mean(-1,keepdim=True)
        hidden_state=hidden_state*torch.rsqrt(variace+self.eps)
        return (hidden_state*self.weight).to(input_type)


if __name__ == '__main__':
    x=RMSNorm(normalized_shape=[3,4])
    y=x(torch.randn(size=(3,4)))
    print(y)

你可能感兴趣的:(torch,深度学习,python,pytorch)