# RMSNorm (and LayerNorm) reference implementations in PyTorch

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine map.

    Normalizes each feature vector to zero mean and unit variance, then applies
    a per-feature scale (``weight``) and shift (``bias``), matching the
    behavior of ``torch.nn.LayerNorm``.

    Args:
        dim: Size of the last (feature) dimension to normalize over.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature scale (gamma) and shift (beta).
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False computes the population variance (divide by N), as
        # torch.nn.LayerNorm does; the torch default unbiased=True (N-1)
        # would give slightly different outputs.
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        # Bug fix: the original referenced an undefined name `xmean` here,
        # raising NameError on every forward pass; `mean` is intended.
        x_norm = (x - mean) * torch.rsqrt(var + self.eps)
        return self.weight * x_norm + self.bias



class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019).

    Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias:
    each feature vector is rescaled by the reciprocal of its root mean
    square, then multiplied by a learnable per-feature gain.

    NOTE(review): the original source was truncated mid-token at
    ``self.weight = nn.Parame`` (a syntax error); this completes the class
    in the standard form, mirroring the sibling LayerNorm above.

    Args:
        dim: Size of the last (feature) dimension to normalize over.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature gain; RMSNorm deliberately has no bias term.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean of squares over the feature dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)

# Tags (from original post): LLM & AIGC & VLP, LLM, RMSNorm