PyTorch LayerNorm Source Code Explained

1. Using LayerNorm

The function is defined in PyTorch as follows:

torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)

The parameters are described below:

  • normalized_shape: the dimensions over which LayerNorm is applied. For a multi-dimensional tensor of shape [N, C, H, W], normalized_shape must match the trailing dimensions of the tensor, here [C, H, W]. Referring to the formula below, $\gamma$ and $\beta$ then have shape [C, H, W], while $x$ and $y$ have shape [N, C, H, W].
  • eps: a small constant added to the denominator to avoid division by zero. Default: 1e-5.
  • elementwise_affine: when True, an element-wise affine transform is applied, i.e. $\gamma$ and $\beta$ take effect and are learned as parameters during training; when False they are not used. $\gamma$ is initialized to all ones and $\beta$ to all zeros. In the code, $\gamma$ corresponds to gamma (weight) and $\beta$ to beta (bias). A quick sanity check of these shapes and defaults follows this list.
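The parameter shapes and defaults described above can be checked with a minimal sketch (using the [C, H, W] example dimensions):

```python
import torch
import torch.nn as nn

C, H, W = 5, 10, 10
ln = nn.LayerNorm([C, H, W])                  # elementwise_affine=True by default
print(ln.weight.shape, ln.bias.shape)         # torch.Size([5, 10, 10]) torch.Size([5, 10, 10])
print(ln.weight.min().item(), ln.bias.max().item())  # 1.0 0.0 -> gamma init to ones, beta to zeros

ln_no_affine = nn.LayerNorm([C, H, W], elementwise_affine=False)
print(ln_no_affine.weight, ln_no_affine.bias)  # None None -- no learnable parameters
```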

LayerNorm is defined mathematically as:

$$
\begin{align*}
Y &= \frac{X - E[X]}{\sqrt{Var[X] + \epsilon}} * \gamma + \beta
\end{align*}
$$

PyTorch usage example: given an [N, C, H, W] tensor, apply LayerNorm over the [C, H, W] dimensions:

>>> # Image Example
>>> import torch
>>> import torch.nn as nn
>>> N, C, H, W = 20, 5, 10, 10
>>> input = torch.randn(N, C, H, W)
>>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
>>> # as shown in the image below
>>> layer_norm = nn.LayerNorm([C, H, W])
>>> output = layer_norm(input)

(Figure 1: normalization is performed over the channel and spatial dimensions for each sample.)
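To connect the formula with this example, here is a minimal numerical check (a sketch, using the default eps) that nn.LayerNorm matches a manual computation of $Y = \frac{X - E[X]}{\sqrt{Var[X]+\epsilon}} * \gamma + \beta$:

```python
import torch
import torch.nn as nn

N, C, H, W = 20, 5, 10, 10
x = torch.randn(N, C, H, W)
ln = nn.LayerNorm([C, H, W])   # eps defaults to 1e-5

# Manual computation: mean/variance over the last three dimensions
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)   # biased variance (divide by N)
y_manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias

print(torch.allclose(ln(x), y_manual, atol=1e-6))   # True (up to float rounding)
```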

2. Deriving the LayerNorm Backward Pass

To simplify the derivation, eps is ignored for now and the input is a one-dimensional vector. LayerNorm is then defined as follows, where $x = [x_1, ..., x_i, ..., x_N]$ is the input vector and $y$ is the output vector with the same shape as $x$; $E[x]$ is the mean, abbreviated $\mu$; $Var[x]$ is the variance, $\frac{1}{N}\sum^N_{i=1}{(x_i-\mu)^2}$; and the standard deviation $\sqrt{Var[x]}$ is abbreviated $\sigma$.

$$
\begin{align*}
y &= \frac{x - E[x]}{\sqrt{Var[x]}} * \gamma + \beta \\
&= \frac{x - \mu}{\sigma} * \gamma + \beta \\
&= \hat{x} * \gamma + \beta \\
\\
\mu &= \frac{1}{N}\sum^N_{j=1}{x_j} \\
\\
\sigma &= \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right)^{\frac{1}{2}} \\
\\
\hat{x} &= \frac{x-\mu}{\sigma}
\end{align*}
$$

Three gradients are needed: the gradients with respect to the parameters gamma ($\gamma$) and beta ($\beta$), and with respect to the input $x$, i.e. $\frac{\partial{l}}{\partial{\gamma}}$, $\frac{\partial{l}}{\partial{\beta}}$ and $\frac{\partial{l}}{\partial{x}}$. Computing $\frac{\partial{l}}{\partial{x}}$ in turn requires $\frac{\partial{\mu}}{\partial{x}}$, $\frac{\partial{\sigma}}{\partial{x}}$ and $\frac{\partial{\hat{x}}}{\partial{x}}$.

$$
\begin{align*}
\frac{\partial{l}}{\partial{\gamma_i}} &= \frac{\partial{l}}{\partial{y_i}} * \frac{\partial{y_i}}{\partial{\gamma_i}} \\
&= \frac{\partial{l}}{\partial{y_i}} * \frac{x_i - \mu}{\sigma} \\
\\
\frac{\partial{l}}{\partial{\beta_i}} &= \frac{\partial{l}}{\partial{y_i}} * \frac{\partial{y_i}}{\partial{\beta_i}} \\
&= \frac{\partial{l}}{\partial{y_i}} * 1 \\
\\
\frac{\partial{\mu}}{\partial{x_i}} &= \frac{1}{N} \\
\\
\frac{\partial{\sigma}}{\partial{x_i}} &= \frac{1}{2} * \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right)^{-\frac{1}{2}} * \frac{\partial{}}{\partial{x_i}} \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right) \\
&= \frac{1}{2} * \sigma^{-1} * \frac{\partial{}}{\partial{x_i}} \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right) \\
&= \frac{1}{2} * \sigma^{-1} * \frac{1}{N} * 2 * (x_i - \mu) \\
&= \sigma^{-1} * \frac{1}{N} * (x_i - \mu) \\
\\
\frac{\partial{\hat{x}_j}}{\partial{x_i}} &= \frac{\partial{(x_j - \mu)}}{\partial{x_i}} * \sigma^{-1} + (x_j - \mu) * (-1) * \sigma^{-2} * \frac{\partial{\sigma}}{\partial{x_i}} \\
&= \sigma^{-1} * \left(\delta_{ij} - \frac{\partial{\mu}}{\partial{x_i}}\right) + \sigma^{-2} * (x_j - \mu) * (-1) * \frac{\partial{\sigma}}{\partial{x_i}} \\
&= \sigma^{-1} * \delta_{ij} + \sigma^{-1} * \left(- \frac{1}{N}\right) + \sigma^{-2} * (x_j - \mu) * (-1) * \frac{\partial{\sigma}}{\partial{x_i}} \\
&= \sigma^{-1} * \delta_{ij} + \sigma^{-1} * \left(- \frac{1}{N}\right) + \sigma^{-3} * \frac{1}{N} * (x_j - \mu) * (x_i - \mu) * (-1) \\
&\quad [\,\delta_{ij}=1 \text{ when } i=j, \text{ otherwise } \delta_{ij}=0\,] \\
\\
\frac{\partial{l}}{\partial{x_i}} &= \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \frac{\partial{y_j}}{\partial{x_i}} \\
&= \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \frac{\partial{y_j}}{\partial{\hat{x}_j}} * \frac{\partial{\hat{x}_j}}{\partial{x_i}} \\
&= \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \left[ \sigma^{-1} * \delta_{ij} + \sigma^{-1} * \left(- \frac{1}{N}\right) + \sigma^{-3} * \frac{1}{N} * (x_j - \mu) * (x_i - \mu) * (-1) \right]
\end{align*}
$$

Here $\gamma_i/\beta_i$ and $x_i$ correspond one-to-one, so no summation is needed for their gradients; but $x_i$ participates in the computation of every $y_j$, so its gradient must accumulate the contributions flowing back through all of the $y_j$. A small numerical check of these formulas follows.
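Below is a minimal sketch (eps omitted and double precision, as in the derivation above) comparing the hand-derived gradients against torch.autograd; the dx expression is the sum over $j$ simplified using $\hat{x}$:

```python
import torch

torch.manual_seed(0)
N = 8
x = torch.randn(N, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(N, dtype=torch.float64, requires_grad=True)
beta = torch.randn(N, dtype=torch.float64, requires_grad=True)
dy = torch.randn(N, dtype=torch.float64)          # upstream gradient dl/dy

mu = x.mean()
sigma = ((x - mu) ** 2).mean().sqrt()             # eps ignored, as in the derivation
x_hat = (x - mu) / sigma
y = x_hat * gamma + beta
y.backward(dy)                                    # autograd reference gradients

# Hand-derived gradients from the formulas above
dgamma = dy * x_hat                               # dl/dgamma_i = dl/dy_i * (x_i - mu) / sigma
dbeta = dy                                        # dl/dbeta_i  = dl/dy_i
dx = (dy * gamma) / sigma \
     - (dy * gamma).sum() / (N * sigma) \
     - x_hat * (dy * gamma * x_hat).sum() / (N * sigma)

print(torch.allclose(gamma.grad, dgamma),
      torch.allclose(beta.grad, dbeta),
      torch.allclose(x.grad, dx))                 # True True True
```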

3. Source Code Implementation

Repository version: https://github.com/pytorch/pytorch/tree/v2.0.1

3.1 Forward Pass

The definition in aten/src/ATen/native/native_functions.yaml is:

- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: layer_norm_cpu
    CUDA: layer_norm_cuda
    MPS: layer_norm_mps
    CompositeExplicitAutograd: math_native_layer_norm
    NestedTensorCPU, NestedTensorCUDA: nested_layer_norm
  autogen: native_layer_norm.out
  tags: core

We take the layer_norm_cpu implementation as an example; it is defined in aten/src/ATen/native/layer_norm.cpp.

In the forward function layer_norm_cpu, the shape derived from input and normalized_shape is folded into a two-dimensional $M \times N$ view of the multi-dimensional tensor. For example, if input has shape [2, 3, 4, 5] and normalized_shape is [4, 5], then M = 2*3 = 6 and N = 4*5 = 20 (a small sketch of this split follows the listing below). The function also initializes the weight (corresponding to gamma) and bias (corresponding to beta) tensors.

std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
    const Tensor& input,
    IntArrayRef normalized_shape, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
    double eps) {
  // Initialize weight and bias from the optional arguments
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  // Compute M and N
  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  // ... (dtype selection and allocation of the output tensor Y elided in this excerpt)
  // Allocate mean/rstd with M elements: one mean and one rstd per group of N input elements
  Tensor mean = at::empty({M}, X->options().dtype(dtype));
  Tensor rstd = at::empty({M}, X->options().dtype(dtype));
  
  // layer_norm_with_mean_rstd_out invokes the forward kernel (LayerNormKernel)
  layer_norm_with_mean_rstd_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
  return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
}
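The M/N folding mentioned above can be illustrated with a short Python sketch (layer_norm_m_n is a hypothetical helper for illustration; the real check lives in _check_layer_norm_inputs):

```python
import math

def layer_norm_m_n(input_shape, normalized_shape):
    """Illustrative stand-in for the M/N folding done by layer_norm_cpu."""
    axis = len(input_shape) - len(normalized_shape)
    assert list(input_shape[axis:]) == list(normalized_shape), \
        "normalized_shape must match the trailing dimensions of the input"
    M = math.prod(input_shape[:axis])   # leading (batch-like) dimensions
    N = math.prod(normalized_shape)     # dimensions being normalized
    return M, N

print(layer_norm_m_n([2, 3, 4, 5], [4, 5]))   # (6, 20)
```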

LayerNormKernel is defined in aten/src/ATen/native/cpu/layer_norm_kernel.cpp; the actual implementation is LayerNormKernelImplInternal, declared as follows:

template <typename T, typename T_ACC>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    T_ACC eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
    ...   
}

Before diving into LayerNormKernelImplInternal, it helps to understand at::parallel_for. Its job is to split the input range into chunks and process them in parallel across threads; the call below splits [0, M) into segments and invokes the lambda on each segment.

  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {...})
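As a rough Python analogue of this chunk-and-dispatch pattern (a sketch only; parallel_for here is a hypothetical stand-in, not how ATen's thread pool actually schedules work):

```python
from concurrent.futures import ThreadPoolExecutor

def parallel_for(begin, end, grain_size, fn, max_threads=4):
    # Split [begin, end) into contiguous chunks and run fn(start, stop)
    # for each chunk on a worker thread.
    total = end - begin
    if total <= 0:
        return
    chunk = max(grain_size, (total + max_threads - 1) // max_threads)
    ranges = [(s, min(s + chunk, end)) for s in range(begin, end, chunk)]
    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        list(pool.map(lambda r: fn(*r), ranges))

# Example: rows 0..M are processed in parallel chunks
M = 10
parallel_for(0, M, 1, lambda start, stop: print(f"processing rows [{start}, {stop})"))
```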

Recall the forward computation:

$$
\begin{align*}
y &= \frac{x - E[x]}{\sqrt{Var[x]+eps}} * \gamma + \beta \\
&= \frac{x - \mu}{\sigma} * \gamma + \beta \\
&= \left(\frac{x}{\sigma} + \frac{- \mu}{\sigma}\right) * \gamma + \beta
\end{align*}
$$

Inside the lambda, for the M * N matrix, each iteration applies LayerNorm to one group of N elements. mean corresponds to $\mu$, rstd_val and scale correspond to $\frac{1}{\sigma}$, and bias corresponds to $\frac{-\mu}{\sigma}$, so $y = (x * scale + bias) * gamma + beta$. A Python sketch mirroring this per-row computation follows the listing.

    for (const auto i : c10::irange(start, end)) {
      const T* X_ptr = X_data + i * N;
      T* Y_ptr = Y_data + i * N;
      T mean_val;
      T rstd_val;
      // 1. Compute the row mean and variance, then convert the variance to rstd = 1/sqrt(var + eps)
      std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
      rstd_val = T(1) / std::sqrt(rstd_val + eps);
      
      const T scale = rstd_val;
      const T bias = -rstd_val * mean_val;
      if (gamma_null || beta_null) {
        for (const auto j : c10::irange(N)) {
          const T gamma_v = gamma_null ? T(1) : gamma_data[j];
          const T beta_v = beta_null ? T(0) : beta_data[j];
          Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
        }
      } else {
        // 2. Apply the LayerNorm forward formula: y = (x * scale + bias) * gamma + beta
        vec::map3<T>(
            [scale, bias](Vec x, Vec gamma, Vec beta) {
              return (x * Vec(scale) + Vec(bias)) * gamma + beta;
            },
            Y_ptr,
            X_ptr,
            gamma_data,
            beta_data,
            N);
      }
      if (!mean_null) {
        mean_data[i] = mean_val;
      }
      if (!rstd_null) {
        rstd_data[i] = rstd_val;
      }
    }
  });
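A simplified Python mirror of this per-row computation (a sketch under the same scale/bias rewrite; layer_norm_forward_rows is a hypothetical helper, not the actual kernel):

```python
import torch

def layer_norm_forward_rows(X, gamma, beta, M, N, eps=1e-5):
    """Per-row sketch of the kernel logic: y = (x * scale + bias) * gamma + beta."""
    X2d = X.reshape(M, N)
    Y = torch.empty_like(X2d)
    mean = torch.empty(M)
    rstd = torch.empty(M)
    for i in range(M):                            # the C++ kernel parallelizes this loop over rows
        x = X2d[i]
        mean_val = x.mean()
        var_val = ((x - mean_val) ** 2).mean()    # per-row mean and variance (RowwiseMoments)
        rstd_val = 1.0 / torch.sqrt(var_val + eps)
        scale = rstd_val                          # 1 / sigma
        bias = -rstd_val * mean_val               # -mu / sigma
        Y[i] = (x * scale + bias) * gamma.reshape(N) + beta.reshape(N)
        mean[i], rstd[i] = mean_val, rstd_val
    return Y.reshape(X.shape), mean, rstd

# Check against nn.LayerNorm on a small [N, C, H, W] input
Nb, C, H, W = 4, 5, 10, 10
X = torch.randn(Nb, C, H, W)
ln = torch.nn.LayerNorm([C, H, W])
Y, _, _ = layer_norm_forward_rows(X, ln.weight.detach(), ln.bias.detach(), M=Nb, N=C * H * W)
print(torch.allclose(Y, ln(X), atol=1e-5))        # True
```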

3.2 Backward Pass

For the backward pass over a multi-dimensional tensor, the input can again be viewed as M vectors of size N. Taking a 5-dimensional tensor of shape $[M_1, M_2, C, H, W]$ with LayerNorm applied over $[C, H, W]$, we have $M = M_1 * M_2$ and $N = C * H * W$.

The definition in aten/src/ATen/native/native_functions.yaml is:

- func: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: layer_norm_backward_cpu
    CUDA: layer_norm_backward_cuda
    MPS: layer_norm_backward_mps
  autogen: native_layer_norm_backward.out
  tags: core

Again we take the CPU implementation, layer_norm_backward_cpu, defined in aten/src/ATen/native/layer_norm.cpp. As with layer_norm_cpu, the backward function initializes the relevant tensors and then calls the kernel.

std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
    const Tensor& dY,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const c10::optional<Tensor>& weight_opt /* optional */,
    const c10::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  ......
  if (M > 0) {
    LayerNormBackwardKernel(
        kCPU, dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
  }
  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));   
}

To make it easier to map onto the PyTorch source later, the final result of the derivation above is expanded as follows:
$$
\begin{align*}
\frac{\partial{l}}{\partial{x_i}} &= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * (\mu - x_i) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (x_j - \mu) \\
&= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (x_j - \mu) + \sigma^{-3} * \frac{1}{N} * (- x_i) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (x_j - \mu) \\
&= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * (-\mu) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) + \sigma^{-3} * \frac{1}{N} * (x_i) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) \\
&= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * (-\mu) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) + \sigma^{-3} * \frac{1}{N} * x_i * \left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] \\
&= \gamma_i * \frac{\partial{l}}{\partial{y_i}} * \sigma^{-1} + \left[ -\sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) - \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j \right] + x_i * \left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] * \sigma^{-3} * \frac{1}{N}
\end{align*}
$$

The kernel is implemented in the LayerNormBackwardKernelImplInternal function in aten/src/ATen/native/cpu/layer_norm_kernel.cpp. The implementation has two stages:

  1. Allocate a buffer of shape {2, max_threads, N}, where buffer[0] serves as dgamma_buffer and buffer[1] as dbeta_buffer. Threads walk the rows in parallel, computing dX and per-thread dgamma/dbeta partial sums from dY and X.
  2. Accumulate the final dgamma/dbeta values by reducing the per-thread buffers.

The code is organized as two nested parallel loops. In the first stage, the outer parallel_for splits the $M*N$ matrix by rows across threads, each thread handling $m_i*N$ elements; in the second stage the accumulation runs over the N columns. The main per-row computation lives in layer_norm_backward_frame, analyzed further below.

template <typename T>
void LayerNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  ......
  // Stage 1: compute per-thread dgamma/dbeta partials and dX
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(
        tid < num_threads,
        "expect thread id smaller than ",
        num_threads,
        ", got thread id ",
        tid);
    T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
    T* dbeta_buffer_ptr =
        dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
    for (const auto i : c10::irange(start, end)) {
      layer_norm_backward_frame<T, T2, T_ACC>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
    }
  });

  // Stage 2: accumulate dgamma/dbeta across the per-thread buffers
  if (buffer_data != nullptr) {
    parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
      for (const auto j : c10::irange(start, end)) {
        T_ACC dgamma_v = T_ACC(0);
        T_ACC dbeta_v = T_ACC(0);
        for (const auto i : c10::irange(num_threads)) {
          dgamma_v += buffer_data[i * N + j];
          dbeta_v += buffer_data[num_threads * N + i * N + j];
        }
        if (!dgamma_null) {
          // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
          dgamma_data[j] = dgamma_v;
        }
        if (!dbeta_null) {
          // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
          dbeta_data[j] = dbeta_v;
        }
      }
    });
  }
  ......
}
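Putting the two stages together, here is a sequential Python sketch of the buffer layout and reduction for dgamma/dbeta (backward_dgamma_dbeta is a hypothetical helper; in the real code stage 1 runs across threads):

```python
import torch

def backward_dgamma_dbeta(dY, X, mean, rstd, M, N, num_threads=4):
    """Sketch of the two-stage reduction: per-thread partial sums, then a column-wise reduce."""
    buffer = torch.zeros(2, num_threads, N)       # buffer[0]: dgamma partials, buffer[1]: dbeta partials
    for i in range(M):                            # stage 1 (rows are split across threads in C++)
        tid = i % num_threads                     # stand-in for at::get_thread_num()
        a = rstd[i]
        b = -a * mean[i]
        buffer[0, tid] += dY[i] * (a * X[i] + b)  # dgamma partial: dY * x_hat
        buffer[1, tid] += dY[i]                   # dbeta partial
    dgamma = buffer[0].sum(dim=0)                 # stage 2: reduce over threads, per column
    dbeta = buffer[1].sum(dim=0)
    return dgamma, dbeta

# Check against the direct reductions over the M rows
M, N = 6, 20
dY, X = torch.randn(M, N), torch.randn(M, N)
mean = X.mean(dim=1)
rstd = 1.0 / X.var(dim=1, unbiased=False).sqrt()
dgamma, dbeta = backward_dgamma_dbeta(dY, X, mean, rstd, M, N)
x_hat = (X - mean[:, None]) * rstd[:, None]
print(torch.allclose(dgamma, (dY * x_hat).sum(dim=0), atol=1e-5),
      torch.allclose(dbeta, dY.sum(dim=0), atol=1e-5))   # True True
```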

The dgamma computation in layer_norm_backward_frame is shown below; it corresponds to the formula $\frac{\partial{l}}{\partial{\gamma_i}} = \frac{\partial{l}}{\partial{y_i}} * \frac{x_i - \mu}{\sigma}$, where $a=\frac{1}{\sigma}$ and $b=\frac{-\mu}{\sigma}=-a*\mu$:

  if (!dgamma_null) {
    const T_ACC a = rstd_data[i];
    const T_ACC b = -a * mean_data[i];
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
    // }
    vec::map3<T>(
        [a, b](Vec dgamma, Vec dy, Vec x) {
          return dgamma + dy * (Vec(a) * x + Vec(b));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dY_ptr,
        X_ptr,
        N);
  }

The dbeta computation in layer_norm_backward_frame corresponds to $\frac{\partial{l}}{\partial{\beta_i}}= \frac{\partial{l}}{\partial{y_i}}$:

  if (!dbeta_null) {
    // Scalar math:
    // for (const auto j : c10::irange(N)) {
    //   dbeta_data[j] += dY_ptr[j];
    // }
    vec::map2<T>(
        [](Vec dbeta, Vec dy) { return dbeta + dy; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dY_ptr,
        N);
  }

The dX computation in layer_norm_backward_frame corresponds to the formula:

$$
\begin{align*}
\frac{\partial{l}}{\partial{x_i}} &= \gamma_i * \frac{\partial{l}}{\partial{y_i}} * \sigma^{-1} + \left[ -\sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) - \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j \right] + x_i * \left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] * \sigma^{-3} * \frac{1}{N}
\end{align*}
$$

The core of layer_norm_backward_frame is implemented as follows:

    if (gamma_null) {
      ......
    } else {
      ds = vec::map3_reduce_all<T>(
          [](Vec x, Vec y, Vec z) { return x * y * z; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          X_ptr,
          gamma_data,
          N);
      db = vec::map2_reduce_all<T>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          dY_ptr,
          gamma_data,
          N);
    }
    const T_ACC a = rstd_data[i];
    const T_ACC b = (db * mean_data[i] - ds) * a * a * a * scale;
    const T_ACC c = -b * mean_data[i] - db * a * scale;
    if (gamma_null) {
      ......
    } else {
      vec::map3<T>(
          [a, b, c](Vec dy, Vec gamma, Vec x) {
            return Vec(a) * dy * gamma + Vec(b) * x + Vec(c);
          },
          dX_ptr,
          dY_ptr,
          gamma_data,
          X_ptr,
          N);
    }
  }

The correspondence between code variables and the formula is listed below; a short numerical check of this mapping follows the list.

  • ds corresponds to $\sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j$
  • db corresponds to $\sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j$
  • a corresponds to $\sigma^{-1}$
  • scale corresponds to $\frac{1}{N}$
  • b corresponds to $\left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] * \sigma^{-3} * \frac{1}{N} = (db * \mu - ds) * a * a * a * scale$
  • c corresponds to $\left[ -\sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) - \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j \right] = -b * \mu - db * a * scale$
  • Final result: dx = Vec(a) * dy * gamma + Vec(b) * x + Vec(c)
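As a final sanity check on this mapping, a minimal sketch (double precision, eps omitted) comparing the direct gradient formula from section 2 with the a/b/c form used by the kernel:

```python
import torch

torch.manual_seed(0)
N = 16
dy = torch.randn(N, dtype=torch.float64)
x = torch.randn(N, dtype=torch.float64)
gamma = torch.randn(N, dtype=torch.float64)

mu = x.mean()
sigma = ((x - mu) ** 2).mean().sqrt()
x_hat = (x - mu) / sigma

# Direct form of dl/dx from the derivation in section 2
dx_direct = (dy * gamma) / sigma \
            - (dy * gamma).sum() / (N * sigma) \
            - x_hat * (dy * gamma * x_hat).sum() / (N * sigma)

# Kernel-style form: dx = a * dy * gamma + b * x + c
scale = 1.0 / N
ds = (dy * gamma * x).sum()       # sum_j dl/dy_j * gamma_j * x_j
db = (dy * gamma).sum()           # sum_j dl/dy_j * gamma_j
a = 1.0 / sigma                   # rstd, i.e. sigma^{-1}
b = (db * mu - ds) * a ** 3 * scale
c = -b * mu - db * a * scale
dx_kernel = a * dy * gamma + b * x + c

print(torch.allclose(dx_direct, dx_kernel))   # True
```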

