The function is defined in PyTorch as follows:
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)
The parameters are described as follows:

- `normalized_shape`: must match the last dimensions of the input. For an input of shape `[N, C, H, W]`, normalized_shape here would be `[C, H, W]`. Comparing with the math formula below, $\gamma$ and $\beta$ both have shape `[C, H, W]`, while $x$ and $y$ have shape `[N, C, H, W]`.
- `eps`: a value added to the denominator for numerical stability; defaults to `1e-5`.
The math formula of LayerNorm is defined as follows:
$$
\begin{align*}
Y &= \frac{X - E[X]}{\sqrt{Var[X] + \epsilon}} * \gamma + \beta
\end{align*}
$$
A PyTorch usage example: given a tensor of shape [N, C, H, W], apply LayerNorm over the [C, H, W] dimensions:
>>> # Image Example
>>> N, C, H, W = 20, 5, 10, 10
>>> input = torch.randn(N, C, H, W)
>>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
>>> layer_norm = nn.LayerNorm([C, H, W])
>>> output = layer_norm(input)
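As a quick sanity check (my own addition, not part of the original example), the output can be reproduced by normalizing manually over the last three dimensions; gamma and beta are initialized to ones and zeros, so the affine step is a no-op here:

import torch
import torch.nn as nn

N, C, H, W = 20, 5, 10, 10
x = torch.randn(N, C, H, W)
layer_norm = nn.LayerNorm([C, H, W])

# Manual computation: per-sample mean/variance over the (C, H, W) dims.
mu = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), keepdim=True, unbiased=False)  # biased variance, as in LayerNorm
manual = (x - mu) / torch.sqrt(var + layer_norm.eps)

print(torch.allclose(layer_norm(x), manual, atol=1e-6))  # True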
To simplify the derivation, eps is ignored for now and the input is taken to be a one-dimensional vector. The corresponding LayerNorm formula is defined below, where $x$ is a one-dimensional vector $[x_1, ..., x_i, ..., x_N]$ and $y$ is the output vector with the same shape as $x$; $E[x]$ is the mean, written $\mu$; $Var[x]$ is the variance, $\frac{1}{N} \sum^N_{i=1}{(x_i-\mu)^2}$; and the standard deviation $\sqrt{Var[x]}$ is written $\sigma$.
$$
\begin{align*}
y &= \frac{x - E[x]}{\sqrt{Var[x]}} * \gamma + \beta \\
&= \frac{x - \mu}{\sigma} * \gamma + \beta \\
&= \hat{x} * \gamma + \beta \\
\mu &= \frac{1}{N}\sum^N_{j=1}{x_j} \\
\sigma &= \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right)^{\frac{1}{2}} \\
\hat{x} &= \frac{x-\mu}{\sigma}
\end{align*}
$$
There are three gradients to compute (i.e. three derivatives to take): with respect to the parameters gamma ($\gamma$) and beta ($\beta$), and with respect to the input x, namely $\frac{\partial{l}}{\partial{\gamma}}$, $\frac{\partial{l}}{\partial{\beta}}$, and $\frac{\partial{l}}{\partial{x}}$. Computing $\frac{\partial{l}}{\partial{x}}$ additionally requires $\frac{\partial{\mu}}{\partial{x}}$, $\frac{\partial{\sigma}}{\partial{x}}$, and $\frac{\partial{\hat{x}}}{\partial{x}}$.
$$
\begin{align*}
\frac{\partial{l}}{\partial{\gamma_i}} &= \frac{\partial{l}}{\partial{y_i}} * \frac{\partial{y_i}}{\partial{\gamma_i}} = \frac{\partial{l}}{\partial{y_i}} * \frac{x_i - \mu}{\sigma} \\
\frac{\partial{l}}{\partial{\beta_i}} &= \frac{\partial{l}}{\partial{y_i}} * \frac{\partial{y_i}}{\partial{\beta_i}} = \frac{\partial{l}}{\partial{y_i}} * 1 \\
\frac{\partial{\mu}}{\partial{x_i}} &= \frac{1}{N} \\
\frac{\partial{\sigma}}{\partial{x_i}} &= \frac{1}{2} * \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right)^{-\frac{1}{2}} * \frac{\partial{}}{\partial{x_i}} \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right) \\
&= \frac{1}{2} * \sigma^{-1} * \frac{\partial{}}{\partial{x_i}} \left( \frac{1}{N} \sum^N_{j=1}{(x_j-\mu)^2} \right) \\
&= \frac{1}{2} * \sigma^{-1} * \frac{1}{N} * 2 * (x_i - \mu) \\
&= \sigma^{-1} * \frac{1}{N} * (x_i - \mu) \\
\frac{\partial{\hat{x}_j}}{\partial{x_i}} &= \frac{\partial{(x_j - \mu)}}{\partial{x_i}} * \sigma^{-1} + (x_j - \mu) * (-1) * \sigma^{-2} * \frac{\partial{\sigma}}{\partial{x_i}} \\
&= \sigma^{-1} * \left(\delta_{ij} - \frac{\partial{\mu}}{\partial{x_i}}\right) + \sigma^{-2} * (x_j - \mu) * (-1) * \frac{\partial{\sigma}}{\partial{x_i}} \\
&= \sigma^{-1} * \delta_{ij} + \sigma^{-1} * \left(- \frac{1}{N}\right) + \sigma^{-2} * (x_j - \mu) * (-1) * \frac{\partial{\sigma}}{\partial{x_i}} \\
&= \sigma^{-1} * \delta_{ij} + \sigma^{-1} * \left(- \frac{1}{N}\right) + \sigma^{-3} * \frac{1}{N} * (x_j - \mu) * (x_i - \mu) * (-1) \\
&\quad \left[\delta_{ij} = 1 \text{ when } i = j, \text{ otherwise } \delta_{ij} = 0\right] \\
\frac{\partial{l}}{\partial{x_i}} &= \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \frac{\partial{y_j}}{\partial{x_i}} = \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \frac{\partial{y_j}}{\partial{\hat{x}_j}} * \frac{\partial{\hat{x}_j}}{\partial{x_i}} \\
&= \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \left[ \sigma^{-1} * \delta_{ij} + \sigma^{-1} * \left(- \frac{1}{N}\right) + \sigma^{-3} * \frac{1}{N} * (x_j - \mu) * (x_i - \mu) * (-1) \right]
\end{align*}
$$
Here $\gamma_i/\beta_i$ correspond one-to-one with $x_i$, so no accumulation is needed for their gradients; but $x_i$ participates in the computation of every $y_j$, so in the backward pass its gradient must accumulate the contributions from all of the related $y_j$.
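The derivation can be verified numerically against autograd. Below is a minimal sketch (my own check, not from the PyTorch source) for the 1-D case; note that with $\sigma = \sqrt{Var[x] + eps}$ the same formulas still hold, since $\frac{\partial{\sigma}}{\partial{x_i}} = \sigma^{-1} * \frac{1}{N} * (x_i - \mu)$ is unchanged:

import torch

torch.manual_seed(0)
N = 8
x = torch.randn(N, dtype=torch.double, requires_grad=True)
gamma = torch.randn(N, dtype=torch.double, requires_grad=True)
beta = torch.randn(N, dtype=torch.double, requires_grad=True)
eps = 1e-5

mu = x.mean()
sigma = torch.sqrt(x.var(unbiased=False) + eps)
x_hat = (x - mu) / sigma
y = x_hat * gamma + beta

dy = torch.randn(N, dtype=torch.double)  # upstream gradient dl/dy
y.backward(dy)

# Closed-form gradients from the derivation above.
xh = x_hat.detach()
g = dy * gamma.detach()                  # dl/dy_j * gamma_j
dgamma = dy * xh                         # dl/dy_i * (x_i - mu) / sigma
dbeta = dy                               # dl/dy_i * 1
dx = (g - g.mean() - xh * (g * xh).mean()) / sigma.detach()

print(torch.allclose(gamma.grad, dgamma))  # True
print(torch.allclose(beta.grad, dbeta))    # True
print(torch.allclose(x.grad, dx))          # True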
Repository version: https://github.com/pytorch/pytorch/tree/v2.0.1
The forward op is defined in aten/src/ATen/native/native_functions.yaml as follows:
- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: layer_norm_cpu
    CUDA: layer_norm_cuda
    MPS: layer_norm_mps
    CompositeExplicitAutograd: math_native_layer_norm
    NestedTensorCPU, NestedTensorCUDA: nested_layer_norm
  autogen: native_layer_norm.out
  tags: core
Here we take the layer_norm_cpu implementation as an example; layer_norm_cpu is defined in aten/src/ATen/native/layer_norm.cpp.
In the forward function layer_norm_cpu, the shape is converted according to input and normalized_shape: the multi-dimensional tensor is flattened into an $M \times N$ two-dimensional matrix. For example, if input has shape [2, 3, 4, 5] and normalized_shape is [4, 5], then M=2*3=6 and N=4*5=20. The function also initializes the weight (corresponding to $\gamma$) and bias (corresponding to $\beta$) tensors.
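As an illustration, the M/N computation can be written as the following Python sketch (layer_norm_m_n is a hypothetical helper, not the ATen code):

import math

def layer_norm_m_n(input_shape, normalized_shape):
    # N: product of the trailing normalized dims; M: product of the leading dims.
    assert list(input_shape[-len(normalized_shape):]) == list(normalized_shape)
    N = math.prod(normalized_shape)
    M = math.prod(input_shape[:-len(normalized_shape)])
    return M, N

print(layer_norm_m_n([2, 3, 4, 5], [4, 5]))  # (6, 20)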
std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
const Tensor& input,
IntArrayRef normalized_shape, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
double eps) {
  // Initialize weight and bias
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
  // Compute M and N
auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
auto M = M_N.first;
auto N = M_N.second;
auto X = input.expect_contiguous();
auto gamma = weight.expect_contiguous();
auto beta = bias.expect_contiguous();
  // Initialize mean/rstd with M elements each; one mean and one rstd are computed per N input elements
Tensor mean = at::empty({M}, X->options().dtype(dtype));
Tensor rstd = at::empty({M}, X->options().dtype(dtype));
  // layer_norm_with_mean_rstd_out invokes the forward kernel (LayerNormKernel)
layer_norm_with_mean_rstd_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
}
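The same entry point can also be exercised from Python; torch.native_layer_norm (assuming it is exposed in your build, as ops from native_functions.yaml usually are) returns the tuple (Y, mean, rstd):

import torch

x = torch.randn(6, 20)
weight, bias = torch.ones(20), torch.zeros(20)
Y, mean, rstd = torch.native_layer_norm(x, [20], weight, bias, 1e-5)
print(Y.shape, mean.shape, rstd.shape)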
LayerNormKernel is defined in aten/src/ATen/native/cpu/layer_norm_kernel.cpp; the actual implementation is LayerNormKernelImplInternal, defined as follows:
template <typename T, typename T_ACC>
void LayerNormKernelImplInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
int64_t M,
int64_t N,
T_ACC eps,
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
...
}
Before walking through LayerNormKernelImplInternal, it helps to understand at::parallel_for: it first splits the input range into chunks and then processes the chunks in parallel across multiple threads. The call below splits [0, M) into segments and invokes the lambda on each segment.
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {...})
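As a rough Python analogy of the chunking behavior (not the actual ATen implementation; parallel_for here is a hypothetical stand-in):

from concurrent.futures import ThreadPoolExecutor

def parallel_for(begin, end, grain_size, fn, num_threads=4):
    # Split [begin, end) into contiguous chunks no smaller than grain_size
    # and invoke fn(start, stop) on each chunk, possibly on different threads.
    chunk = max(grain_size, (end - begin + num_threads - 1) // num_threads)
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        for s in range(begin, end, chunk):
            pool.submit(fn, s, min(s + chunk, end))

parallel_for(0, 10, 1, lambda start, stop: print(f"processing rows [{start}, {stop})"))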
Recall the forward computation:
$$
\begin{align*}
y &= \frac{x - E[x]}{\sqrt{Var[x]+eps}} * \gamma + \beta \\
&= \frac{x - \mu}{\sigma} * \gamma + \beta \\
&= \left(\frac{x}{\sigma} + \frac{-\mu}{\sigma}\right) * \gamma + \beta
\end{align*}
$$
Inside the lambda, for the M * N matrix, each iteration applies LayerNorm to N elements. mean corresponds to $\mu$, rstd_val and scale correspond to $\frac{1}{\sigma}$, and bias corresponds to $\frac{-\mu}{\sigma}$; therefore $y = (x * scale + bias) * gamma + beta$.
for (const auto i : c10::irange(start, end)) {
const T* X_ptr = X_data + i * N;
T* Y_ptr = Y_data + i * N;
T mean_val;
T rstd_val;
    // 1. Compute the row mean and variance, then convert the variance to rstd_val = 1/sqrt(var + eps)
std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
rstd_val = T(1) / std::sqrt(rstd_val + eps);
const T scale = rstd_val;
const T bias = -rstd_val * mean_val;
if (gamma_null || beta_null) {
for (const auto j : c10::irange(N)) {
const T gamma_v = gamma_null ? T(1) : gamma_data[j];
const T beta_v = beta_null ? T(0) : beta_data[j];
Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
}
} else {
      // 2. Apply the layer norm forward formula (vectorized)
vec::map3<T>(
[scale, bias](Vec x, Vec gamma, Vec beta) {
return (x * Vec(scale) + Vec(bias)) * gamma + beta;
},
Y_ptr,
X_ptr,
gamma_data,
beta_data,
N);
}
if (!mean_null) {
mean_data[i] = mean_val;
}
if (!rstd_null) {
rstd_data[i] = rstd_val;
}
}
}
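The per-row logic can be mirrored in a short Python sketch (a simplification of the kernel, using the same scale/bias trick; layer_norm_forward_rows is a hypothetical helper):

import torch

def layer_norm_forward_rows(X, gamma, beta, eps=1e-5):
    # X: (M, N); gamma, beta: (N,). One mean/rstd per row, as in the kernel.
    M, N = X.shape
    Y = torch.empty_like(X)
    mean = X.mean(dim=1)
    rstd = 1.0 / torch.sqrt(X.var(dim=1, unbiased=False) + eps)
    for i in range(M):
        scale = rstd[i]              # 1 / sigma
        bias = -scale * mean[i]      # -mu / sigma
        Y[i] = (X[i] * scale + bias) * gamma + beta
    return Y, mean, rstd

X = torch.randn(6, 20)
Y, mean, rstd = layer_norm_forward_rows(X, torch.ones(20), torch.zeros(20))
print(torch.allclose(Y, torch.nn.functional.layer_norm(X, (20,)), atol=1e-6))  # True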
For the backward pass over a multi-dimensional tensor, the input can be viewed as M vectors of size N. Taking a 5-D tensor as an example, with input shape $[M_1, M_2, C, H, W]$ and layer_norm dims $[C, H, W]$, we have $M = M_1 * M_2$ and $N = C * H * W$.
The backward op is defined in aten/src/ATen/native/native_functions.yaml as follows:
- func: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: layer_norm_backward_cpu
    CUDA: layer_norm_backward_cuda
    MPS: layer_norm_backward_mps
  autogen: native_layer_norm_backward.out
  tags: core
Here we take the layer_norm_backward_cpu implementation as an example; layer_norm_backward_cpu is defined in aten/src/ATen/native/layer_norm.cpp. Similar to layer_norm_cpu, the backward initializes the relevant tensors and then invokes the kernel.
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
const Tensor& dY,
const Tensor& input,
IntArrayRef normalized_shape,
const Tensor& mean,
const Tensor& rstd,
    const c10::optional<Tensor>& weight_opt /* optional */,
    const c10::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
......
if (M > 0) {
LayerNormBackwardKernel(
kCPU, dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
}
return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
}
To make it easier to map to the PyTorch source later, expand the final result of the derivation above as follows:
$$
\begin{align*}
\frac{\partial{l}}{\partial{x_i}} &= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * (\mu - x_i) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (x_j - \mu) \\
&= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (x_j - \mu) + \sigma^{-3} * \frac{1}{N} * (-x_i) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (x_j - \mu) \\
&= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * (-\mu) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) + \sigma^{-3} * \frac{1}{N} * x_i * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) \\
&= \sigma^{-1} * \frac{\partial{l}}{\partial{y_i}} * \gamma_i + (-1) * \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j + \sigma^{-3} * \frac{1}{N} * (-\mu) * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) + \sigma^{-3} * \frac{1}{N} * x_i * \left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] \\
&= \gamma_i * \frac{\partial{l}}{\partial{y_i}} * \sigma^{-1} + \left[ -\sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) - \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j \right] + x_i * \left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] * \sigma^{-3} * \frac{1}{N}
\end{align*}
$$
The kernel implementation lives in the LayerNormBackwardKernelImplInternal function in aten/src/ATen/native/cpu/layer_norm_kernel.cpp, and proceeds in two stages:
1. Allocate a buffer of shape {2, max_threads, N}, where buffer[0] is used as dgamma_buffer and buffer[1] as dbeta_buffer. Threads process their assigned rows of dY and X in parallel, each writing dX and accumulating partial dgamma/dbeta values into its own slice of the buffer, reusing X[i] and dY[i].
2. Reduce the per-thread buffers over the thread dimension to obtain the final accumulated dgamma/dbeta (a Python sketch of this two-stage scheme follows the code below).
In the code this is implemented with two nested levels: in the first step, the outer loop parallelizes over the rows of the $M*N$ matrix, with each thread handling $m_i * N$ elements; the second step accumulates element-wise over the N columns. The main computation logic lives in the layer_norm_backward_frame function, analyzed further below.
template <typename T>
void LayerNormBackwardKernelImplInternal(
const Tensor& dY,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const Tensor& gamma,
int64_t M,
int64_t N,
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
......
  // Step 1: compute per-thread dgamma/dbeta partials and dX
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
int tid = at::get_thread_num();
TORCH_CHECK(
tid < num_threads,
"expect thread id smaller than ",
num_threads,
", got thread id ",
tid);
T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
T* dbeta_buffer_ptr =
dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
for (const auto i : c10::irange(start, end)) {
layer_norm_backward_frame<T, T2, T_ACC>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
}
});
  // Step 2: accumulate dgamma/dbeta across threads
if (buffer_data != nullptr) {
parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
for (const auto j : c10::irange(start, end)) {
T_ACC dgamma_v = T_ACC(0);
T_ACC dbeta_v = T_ACC(0);
for (const auto i : c10::irange(num_threads)) {
dgamma_v += buffer_data[i * N + j];
dbeta_v += buffer_data[num_threads * N + i * N + j];
}
if (!dgamma_null) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
dgamma_data[j] = dgamma_v;
}
if (!dbeta_null) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
dbeta_data[j] = dbeta_v;
}
}
});
}
......
}
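The two-stage dgamma/dbeta scheme can be sketched in Python as follows (my own simplification; threading itself is omitted, and backward_dgamma_dbeta is a hypothetical helper; the buffer layout matches {2, num_threads, N}):

import torch

def backward_dgamma_dbeta(dY, X, mean, rstd, num_threads=4):
    # Stage 1: each "thread" accumulates partial dgamma/dbeta for its rows
    # into its own slice of the buffer; stage 2 reduces over the thread axis.
    M, N = dY.shape
    buf = torch.zeros(2, num_threads, N, dtype=dY.dtype)
    for i in range(M):
        tid = i % num_threads                    # stand-in for at::get_thread_num()
        a = rstd[i]
        b = -a * mean[i]
        buf[0, tid] += dY[i] * (a * X[i] + b)    # dgamma partial: dy * x_hat
        buf[1, tid] += dY[i]                     # dbeta partial: dy
    return buf[0].sum(dim=0), buf[1].sum(dim=0)

# Check against autograd.
X = torch.randn(6, 20, dtype=torch.double)
w = torch.ones(20, dtype=torch.double, requires_grad=True)
b = torch.zeros(20, dtype=torch.double, requires_grad=True)
Y = torch.nn.functional.layer_norm(X, (20,), w, b)
dY = torch.randn_like(Y)
Y.backward(dY)
mean = X.mean(dim=1)
rstd = 1.0 / torch.sqrt(X.var(dim=1, unbiased=False) + 1e-5)
dgamma, dbeta = backward_dgamma_dbeta(dY, X, mean, rstd)
print(torch.allclose(w.grad, dgamma), torch.allclose(b.grad, dbeta))  # True True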
The dgamma computation in layer_norm_backward_frame is shown below, corresponding to the formula $\frac{\partial{l}}{\partial{\gamma_i}} = \frac{\partial{l}}{\partial{y_i}} * \frac{x_i - \mu}{\sigma}$, where $a=\frac{1}{\sigma}$ and $b=\frac{-\mu}{\sigma}=-a*\mu$.
if (!dgamma_null) {
const T_ACC a = rstd_data[i];
const T_ACC b = -a * mean_data[i];
// Scalar math:
// for (const auto j : c10::irange(N)) {
// dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
// }
vec::map3<T>(
[a, b](Vec dgamma, Vec dy, Vec x) {
return dgamma + dy * (Vec(a) * x + Vec(b));
},
dgamma_buffer_ptr,
dgamma_buffer_ptr,
dY_ptr,
X_ptr,
N);
}
The dbeta computation in layer_norm_backward_frame is shown below, corresponding to the formula $\frac{\partial{l}}{\partial{\beta_i}} = \frac{\partial{l}}{\partial{y_i}}$.
if (!dbeta_null) {
// Scalar math:
// for (const auto j : c10::irange(N)) {
// dbeta_data[j] += dY_ptr[j];
// }
vec::map2<T>(
[](Vec dbeta, Vec dy) { return dbeta + dy; },
dbeta_buffer_ptr,
dbeta_buffer_ptr,
dY_ptr,
N);
}
The dx computation in layer_norm_backward_frame corresponds to the following formula:
$$
\begin{align*}
\frac{\partial{l}}{\partial{x_i}} &= \gamma_i * \frac{\partial{l}}{\partial{y_i}} * \sigma^{-1} + \left[ -\sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) - \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j \right] + x_i * \left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] * \sigma^{-3} * \frac{1}{N}
\end{align*}
$$
The core implementation in layer_norm_backward_frame is as follows:
if (gamma_null) {
......
} else {
ds = vec::map3_reduce_all<T>(
[](Vec x, Vec y, Vec z) { return x * y * z; },
[](Vec x, Vec y) { return x + y; },
dY_ptr,
X_ptr,
gamma_data,
N);
db = vec::map2_reduce_all<T>(
[](Vec x, Vec y) { return x * y; },
[](Vec x, Vec y) { return x + y; },
dY_ptr,
gamma_data,
N);
}
const T_ACC a = rstd_data[i];
const T_ACC b = (db * mean_data[i] - ds) * a * a * a * scale;
const T_ACC c = -b * mean_data[i] - db * a * scale;
if (gamma_null) {
......
} else {
vec::map3<T>(
[a, b, c](Vec dy, Vec gamma, Vec x) {
return Vec(a) * dy * gamma + Vec(b) * x + Vec(c);
},
dX_ptr,
dY_ptr,
gamma_data,
X_ptr,
N);
}
}
The variables in the code map to the formula as follows:
- ds corresponds to $\sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j$
- db corresponds to $\sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j$
- a corresponds to $\sigma^{-1}$
- scale corresponds to $\frac{1}{N}$
- b corresponds to $\left[ \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * \mu - \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * x_j \right] * \sigma^{-3} * \frac{1}{N} = (db * \mu - ds) * a * a * a * scale$
- c corresponds to $\left[ -\sigma^{-3} * \frac{1}{N} * \mu * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j * (\mu - x_j) - \sigma^{-1} * \frac{1}{N} * \sum_{j=1}^N \frac{\partial{l}}{\partial{y_j}} * \gamma_j \right] = -b * \mu - db * a * scale$
- dx = Vec(a) * dy * gamma + Vec(b) * x + Vec(c)
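Putting the mapping together, the per-row dX computation can be reproduced in Python (a sketch following the C++ variable names, checked against autograd; backward_dx_row is a hypothetical helper, not the kernel itself):

import torch

def backward_dx_row(dy, x, mean_i, rstd_i, gamma, N):
    # Mirrors the ds/db/a/b/c computation in layer_norm_backward_frame.
    ds = (dy * x * gamma).sum()     # sum_j dl/dy_j * gamma_j * x_j
    db = (dy * gamma).sum()         # sum_j dl/dy_j * gamma_j
    a = rstd_i                      # sigma^{-1}
    scale = 1.0 / N                 # 1/N
    b = (db * mean_i - ds) * a * a * a * scale
    c = -b * mean_i - db * a * scale
    return a * dy * gamma + b * x + c

torch.manual_seed(0)
N = 16
x = torch.randn(N, dtype=torch.double, requires_grad=True)
gamma = torch.randn(N, dtype=torch.double)
y = torch.nn.functional.layer_norm(x, (N,), weight=gamma, eps=1e-5)
dy = torch.randn(N, dtype=torch.double)
y.backward(dy)

mean_i = x.detach().mean()
rstd_i = 1.0 / torch.sqrt(x.detach().var(unbiased=False) + 1e-5)
print(torch.allclose(x.grad, backward_dx_row(dy, x.detach(), mean_i, rstd_i, gamma, N)))  # True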