affine/linear(仿射/线性)变换函数详解及全连接层反向传播的梯度求导

摘要

Affine 仿射层, 又称 Linear 线性变换层, 常用于神经网络结构中的全连接层.
本文给出了 Affine 层的两种定义及相关的反向传播梯度.

相关

配套代码, 请参考文章 :

Python和PyTorch对比实现affine/linear(仿射/线性)变换函数及全连接层的反向传播

系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

1. Affine 的一种定义

考虑一个输入向量 x, Affine 层的权重为 k 维向量 w, 偏置为标量 b, 则 :
x = ( x 1 , x 2 , x 3 , ⋯   , x k )    w = ( w 1 , w 2 , w 3 , ⋯   , w k )    a f f i n e ( x i , w i , b ) = x i w i + b x = (x_1,x_2,x_3,\cdots,x_k)\\ \;\\ w = (w_1, w_2,w_3,\cdots,w_k)\\ \;\\ affine(x_i,w_i,b) = x_iw_i+b x=(x1,x2,x3,,xk)w=(w1,w2,w3,,wk)affine(xi,wi,b)=xiwi+b

使用 X 表示 m 行 k 列的矩阵, 偏置为标量 b, 则一次仿射变换为 :
a T = a f f i n e ( X , w , b ) = X w T + b    a T = ( x 11 x 12 x 13 ⋯ x 1 k x 21 x 22 x 23 ⋯ x 2 k x 31 x 32 x 33 ⋯ x 3 k ⋮ ⋮ ⋮ ⋱ ⋮ x m 1 x m 2 x m 3 ⋯ x m k ) ( w 1 w 2 w 3 ⋮ w k ) + b    a = ( a 1 , a 2 , a 3 , ⋯   , a k ) a^T=affine(X,w,b) = Xw^T + b\\\;\\ a^T= \begin{pmatrix} x_{11}&x_{12} &x_{13}&\cdots&x_{1k}\\ x_{21}&x_{22}&x_{23}&\cdots&x_{2k}\\ x_{31}&x_{32}&x_{33}&\cdots&x_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{m1}&x_{m2}&x_{m3}&\cdots&x_{mk} \end{pmatrix} \begin{pmatrix} w_1\\ w_2\\ w_3\\ \vdots\\ w_k \end{pmatrix} +b\\ \;\\ a= (a_1,a_2,a_3,\cdots,a_k) aT=affine(X,w,b)=XwT+baT=x11x21x31xm1x12x22x32xm2x13x23x33xm3x1kx2kx3kxmkw1w2w3wk+ba=(a1,a2,a3,,ak)

更一般的, 若使用 W 表示 n 行 k 列的矩阵, 偏置为向量 b , 则 n 次仿射变换为 :
W n × k = ( w 11 w 12 w 13 ⋯ w 1 k w 21 w 22 w 23 ⋯ w 2 k w 31 w 32 w 33 ⋯ w 3 k ⋮ ⋮ ⋮ ⋱ ⋮ w n 1 w n 2 w n 3 ⋯ w n k )    b 1 × n = ( b 1 , b 2 , b 3 , ⋯   , b n )    A m × n = a f f i n e ( X , W , b ) = X m × k W n × k T + b 1 × n W_{n\times k} =\begin{pmatrix} w_{11}&w_{12} &w_{13}&\cdots&w_{1k}\\ w_{21}&w_{22}&w_{23}&\cdots&w_{2k}\\ w_{31}&w_{32}&w_{33}&\cdots&w_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ w_{n1}&w_{n2}&w_{n3}&\cdots&w_{nk} \end{pmatrix}\\ \;\\ b_{1 \times n} = (b_1,b_2,b_3,\cdots,b_n)\\\;\\ A_{m\times n} = affine(X,W,b) = X_{m\times k}W^T_{n\times k} + b_{1 \times n} Wn×k=w11w21w31wn1w12w22w32wn2w13w23w33wn3w1kw2kw3kwnkb1×n=(b1,b2,b3,,bn)Am×n=affine(X,W,b)=Xm×kWn×kT+b1×n

使用求和符号表示 A 矩阵中的元素 :
a i j = ∑ t = 1 k x i t ⋅ w j t + b j a_{ij} =\sum_{t=1}^{k} x_{it} \cdot w_{jt} + b_j aij=t=1kxitwjt+bj

取其中一项展开作为示例 :
a 23 = ∑ t = 1 k x 2 t ⋅ w 3 t + b 3 = x 21 w 31 + x 22 w 32 + x 23 w 33 + ⋯ + x 2 k w 3 k + b 3 a_{23} =\sum_{t=1}^{k} x_{2t} \cdot w_{3t} + b_3= x_{21}w_{31}+x_{22}w_{32}+x_{23}w_{33}+\cdots+x_{2k}w_{3k}+ b_3 a23=t=1kx2tw3t+b3=x21w31+x22w32+x23w33++x2kw3k+b3

2. 梯度的定义

三维XYZ空间中的梯度定义:
∇ e ( 3 ) = ∂ e ∂ x i + ∂ e ∂ y j + ∂ e ∂ z k \nabla e_{(3)} = \frac{\partial e}{\partial x}i+\frac{\partial e}{\partial y}j+\frac{\partial e}{\partial z}k e(3)=xei+yej+zek

式中, i , j , k i, j, k i,j,k是三个两两相互垂直的单位向量, 或 i , j , k i, j, k i,j,k 是正交单位向量组, 或 i , j , k i, j, k i,j,k 是一组线性无关的单位向量, 这三种说法是等价的.

推广到 t t t 维向量空间 V V V, 若 t t t 个向量 I 1 , I 2 , I 3 , ⋯   , I t I_1, I_2, I_3,\cdots, I_t I1,I2,I3,,It 是一组两两正交的单位向量, 或单位向量组 I 1 , I 2 , I 3 , ⋯   , I t I_1, I_2, I_3,\cdots, I_t I1,I2,I3,,It 线性无关, 那么, 该向量空间 V V V 中的梯度可定义为 :
∇ e ( V ) = ∂ e ∂ x 1 I 1 + ∂ e ∂ x 2 I 2 + ∂ e ∂ x 3 I 3 + ⋯ + ∂ e ∂ x t I t \nabla e_{(V)} = \frac{\partial e}{\partial x_1}I_1+\frac{\partial e}{\partial x_2}I_2+\frac{\partial e}{\partial x_3}I_3+\cdots+\frac{\partial e}{\partial x_t}I_t e(V)=x1eI1+x2eI2+x3eI3++xteIt

梯度的定义可以在 <高等数学> 中找到, 正交和线性无关的定义可以在 <线性代数> 中找到.

3. 反向传播中的梯度求导

若 X 矩阵经过 affine 层变换得到 A 矩阵, 往前 forward 传播得到误差值 error (标量 e ), 求 e 关于 X 的梯度:
A m × n = X m × k W n × k T + b 1 × n    e = f o r w a r d ( A ) A_{m \times n} = X_{m\times k}{W_{n\times k}}^T + b_{1 \times n}\\ \;\\ e=forward(A) Am×n=Xm×kWn×kT+b1×ne=forward(A)

3.1 损失值 e 对 A 矩阵的梯度

首先, 我们说求梯度, 究竟是在求什么?
答 : 一个让损失值 e 变小的最快的方向.

比如, e 对 A 的梯度矩阵 :
d e d A = ( ∂ e / ∂ a 11 ∂ e / ∂ a 12 ∂ e / ∂ a 13 ⋯ ∂ e / ∂ a 1 n ∂ e / ∂ a 21 ∂ e / ∂ a 22 ∂ e / ∂ a 23 ⋯ ∂ e / ∂ a 2 n ∂ e / ∂ a 31 ∂ e / ∂ a 32 ∂ e / ∂ a 33 ⋯ ∂ e / ∂ a 3 n ⋮ ⋮ ⋮ ⋱ ⋮ ∂ e / ∂ a m 1 ∂ e / ∂ a m 2 ∂ e / ∂ a m 3 ⋯ ∂ e / ∂ a m n ) \frac{de}{dA} = \begin{pmatrix} \partial e/ \partial a_{11}&\partial e/ \partial a_{12}&\partial e/ \partial a_{13}&\cdots& \partial e/ \partial a_{1n}\\ \partial e/ \partial a_{21}&\partial e/ \partial a_{22}&\partial e/ \partial a_{23}&\cdots& \partial e/ \partial a_{2n}\\ \partial e/ \partial a_{31}&\partial e/ \partial a_{32}&\partial e/ \partial a_{33}&\cdots& \partial e/ \partial a_{3n}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \partial e/ \partial a_{m1}&\partial e/ \partial a_{m2}&\partial e/ \partial a_{m3}&\cdots& \partial e/ \partial a_{mn}\\ \end{pmatrix} dAde=e/a11e/a21e/a31e/am1e/a12e/a22e/a32e/am2e/a13e/a23e/a33e/am3e/a1ne/a2ne/a3ne/amn

为了书写方便, 记 :
∂ e ∂ a i j = a i j ′    ∇ e ( A ) = d e d A = ( a 11 ′ a 12 ′ a 13 ′ ⋯ a 1 n ′ a 21 ′ a 22 ′ a 23 ′ ⋯ a 2 n ′ a 31 ′ a 32 ′ a 33 ′ ⋯ a 3 n ′ ⋮ ⋮ ⋮ ⋱ ⋮ a m 1 ′ a m 2 ′ a m 3 ′ ⋯ a m n ′ ) \frac{\partial e}{\partial a_{ij}} = a_{ij}'\\ \;\\ \nabla e_{(A)}= \frac{de}{dA} = \begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix} aije=aije(A)=dAde=a11a21a31am1a12a22a32am2a13a23a33am3a1na2na3namn

所有的 a i j ′ a_{ij}' aij 都是已知的, 是上游的 forward 函数帮我们算好的.
只要矩阵 A 中所有的元素按照这个矩阵等比例的更新, 那么就是使 e 值减少最快的方向.
梯度本身的定义并不是一个矩阵, 而是一个向量 :
∇ e ( A ) = ( a 11 ′ , a 12 ′ , ⋯   , a 21 ′ , a 22 ′ , ⋯   , a m 1 ′ , a m 2 ′ , ⋯   , a m n ′ ) \nabla e_{(A)}= (a_{11}', a_{12}',\cdots, a_{21}', a_{22}',\cdots,a_{m1}', a_{m2}',\cdots, a_{mn}') e(A)=(a11,a12,,a21,a22,,am1,am2,,amn)

这个写法和上面的矩阵写法是等价的.
利用矩阵求导的写法求梯度, 求的是方向导数, 或者单位向量的系数, 和普通的矩阵求导有区别.

3.2 A 矩阵的元素关于 X 的梯度

A m × n = X m × k W n × k T + b 1 × n A_{m \times n} = X_{m\times k}{W_{n\times k}}^T + b_{1 \times n}\\ Am×n=Xm×kWn×kT+b1×n

根据矩阵乘法行乘列的定义, 矩阵 X X X W T W^T WT 中的第 j j j 列向量相乘, 将降维得到一个新的列向量, 作为矩阵 A 中的第 j j j 列向量, 即 :
W j = ( w j 1 , w j 2 , w j 3 , ⋯   , w j k )    X W j T = ( a 1 j a 2 j a 3 j ⋮ a m j ) = A : , j W_j=(w_{j1},w_{j2},w_{j3},\cdots,w_{jk})\\ \;\\ XW_j^T= \begin{pmatrix} a_{1j}\\ a_{2j}\\ a_{3j}\\ \vdots\\ a_{mj} \end{pmatrix}=A_{:,j} Wj=(wj1,wj2,wj3,,wjk)XWjT=a1ja2ja3jamj=A:,j

上面的 : , j :,j :,j 符号表示取矩阵中 j j j 列的所有行, 结果是一个列向量. 参考的是 numpy 的记法.
矩阵 A 中任意元素的梯度 :
d a i j d X = ( ∂ a i j / ∂ x 11 ∂ a i j / ∂ x 12 ∂ a i j / ∂ x 13 ⋯ ∂ a i j / ∂ x 1 k ∂ a i j / ∂ x 21 ∂ a i j / ∂ x 22 ∂ a i j / ∂ x 23 ⋯ ∂ a i j / ∂ x 2 k ∂ a i j / ∂ x 31 ∂ a i j / ∂ x 32 ∂ a i j / ∂ x 33 ⋯ ∂ a i j / ∂ x 3 k ⋮ ⋮ ⋮ ⋱ ⋮ ∂ a i j / ∂ x m 1 ∂ a i j / ∂ x m 2 ∂ a i j / ∂ x m 3 ⋯ ∂ a i j / ∂ x m k ) \frac{d a_{ij}}{dX} = \begin{pmatrix} \partial a_{ij}/ \partial x_{11}&\partial a_{ij}/ \partial x_{12}&\partial a_{ij}/ \partial x_{13}&\cdots& \partial a_{ij}/ \partial x_{1k}\\ \partial a_{ij}/ \partial x_{21}&\partial a_{ij}/ \partial x_{22}&\partial a_{ij}/ \partial x_{23}&\cdots& \partial a_{ij}/ \partial x_{2k}\\ \partial a_{ij}/ \partial x_{31}&\partial a_{ij}/ \partial x_{32}&\partial a_{ij}/ \partial x_{33}&\cdots& \partial a_{ij}/\partial x_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \partial a_{ij}/ \partial x_{m1}&\partial a_{ij}/ \partial x_{m2}&\partial a_{ij}/ \partial x_{m3}&\cdots& \partial a_{ij}/ \partial x_{mk}\\ \end{pmatrix} dXdaij=aij/x11aij/x21aij/x31aij/xm1aij/x12aij/x22aij/x32aij/xm2aij/x13aij/x23aij/x33aij/xm3aij/x1kaij/x2kaij/x3kaij/xmk

为了书写方便, 记 :
∂ a i j ∂ x p q = x i j ∣ p q ′    ∇ a i j ( X ) = d a i j d X = ( x i j ∣ 11 ′ x i j ∣ 12 ′ x i j ∣ 13 ′ ⋯ x i j ∣ 1 k ′ x i j ∣ 21 ′ x i j ∣ 22 ′ x i j ∣ 23 ′ ⋯ x i j ∣ 2 k ′ x i j ∣ 31 ′ x i j ∣ 32 ′ x i j ∣ 33 ′ ⋯ x i j ∣ 3 k ′ ⋮ ⋮ ⋮ ⋱ ⋮ x i j ∣ m 1 ′ x i j ∣ m 2 ′ x i j ∣ m 3 ′ ⋯ x i j ∣ m k ′ ) \frac{\partial a_{ij}}{\partial x_{pq}} = x_{ij|pq}'\\ \;\\ \nabla {a_{ij}}_{(X)}=\frac{d a_{ij}}{dX} = \begin{pmatrix} x_{ij|11}'&x_{ij|12}'&x_{ij|13}'&\cdots&x_{ij|1k}'\\ x_{ij|21}'&x_{ij|22}'&x_{ij|23}'&\cdots&x_{ij|2k}'\\ x_{ij|31}'&x_{ij|32}'&x_{ij|33}'&\cdots&x_{ij|3k}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{ij|m1}'&x_{ij|m2}'&x_{ij|m3}'&\cdots&x_{ij|mk}'\\ \end{pmatrix} xpqaij=xijpqaij(X)=dXdaij=xij11xij21xij31xijm1xij12xij22xij32xijm2xij13xij23xij33xijm3xij1kxij2kxij3kxijmk

3.3 关于 X 的反向传播

按照矩阵元素的定义 :
a i j = ∑ t = 1 k x i t ⋅ w j t + b j    a i j = x i 1 w j 1 + x i 2 w j 2 + ⋯ + x i q w j q + ⋯ + x i k w j k + b j    x i j ∣ p q ′ = ∂ a i j ∂ x p q = { w j q p = i 0 , p ≠ i a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{jt} +b_j\\ \;\\ a_{ij}= x_{i1}w_{j1} +x_{i2}w_{j2} +\cdots+x_{iq}w_{jq} +\cdots+x_{ik}w_{jk} +b_j\\ \;\\ x_{ij|pq}'=\frac{\partial a_{ij}}{\partial x_{pq}} = \left\{ \begin{array}{rr} w_{jq}& p = i\\ 0, & p \neq i \end{array} \right.\\ aij=t=1kxitwjt+bjaij=xi1wj1+xi2wj2++xiqwjq++xikwjk+bjxijpq=xpqaij={wjq0,p=ip̸=i

根据 <高等数学> 中介绍的复合函数求导法则, 知 :
∂ e ∂ x p q = ∑ i = 1 i = m ∑ j = 1 j = n ∂ e ∂ a i j ∂ a i j ∂ x p q = ∑ i = 1 i = m ∑ j = 1 j = n a i j ′ x i j ∣ p q ′ \frac {\partial e}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' x_{ij|pq}'\\ xpqe=i=1i=mj=1j=naijexpqaij=i=1i=mj=1j=naijxijpq

删除零项 :
∂ e ∂ x p q = ∑ j = 1 j = n a p j ′ w j q    d e d X = ( ∑ j = 1 j = n a 1 j ′ w j 1 ∑ j = 1 j = n a 1 j ′ w j 2 ∑ j = 1 j = n a 1 j ′ w j 3 ⋯ ∑ j = 1 j = n a 1 j ′ w j k    ∑ j = 1 j = n a 2 j ′ w j 1 ∑ j = 1 j = n a 2 j ′ w j 2 ∑ j = 1 j = n a 2 j ′ w j 3 ⋯ ∑ j = 1 j = n a 2 j ′ w j k    ∑ j = 1 j = n a 3 j ′ w j 1 ∑ j = 1 j = n a 3 j ′ w j 2 ∑ j = 1 j = n a 3 j ′ w j 3 ⋯ ∑ j = 1 j = n a 3 j ′ w j k ⋮ ⋮ ⋮ ⋱ ⋮ ∑ j = 1 j = n a m j ′ w j 1 ∑ j = 1 j = n a m j ′ w j 2 ∑ j = 1 j = n a m j ′ w j 3 ⋯ ∑ j = 1 j = n a m j ′ w j k ) \frac {\partial e}{\partial x_{pq}}=\sum_{j =1}^{j =n} a_{pj}'w_{jq}\\ \;\\ \frac {d e}{d X}=\begin{pmatrix} \sum_{j =1}^{j =n} a_{1j}'w_{j1}&\sum_{j =1}^{j =n} a_{1j}'w_{j2}&\sum_{j =1}^{j =n} a_{1j}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{1j}'w_{jk}\\\;\\ \sum_{j =1}^{j =n} a_{2j}'w_{j1}&\sum_{j =1}^{j =n} a_{2j}'w_{j2}&\sum_{j =1}^{j =n} a_{2j}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{2j}'w_{jk}\\\;\\ \sum_{j =1}^{j =n} a_{3j}'w_{j1}&\sum_{j =1}^{j =n} a_{3j}'w_{j2}&\sum_{j =1}^{j =n} a_{3j}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{3j}'w_{jk}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{j =1}^{j =n} a_{mj}'w_{j1}&\sum_{j =1}^{j =n} a_{mj}'w_{j2}&\sum_{j =1}^{j =n} a_{mj}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{mj}'w_{jk}\\ \end{pmatrix} xpqe=j=1j=napjwjqdXde=j=1j=na1jwj1j=1j=na2jwj1j=1j=na3jwj1j=1j=namjwj1j=1j=na1jwj2j=1j=na2jwj2j=1j=na3jwj2j=1j=namjwj2j=1j=na1jwj3j=1j=na2jwj3j=1j=na3jwj3j=1j=namjwj3j=1j=na1jwjkj=1j=na2jwjkj=1j=na3jwjkj=1j=namjwjk

这个结果恰好满足矩阵乘法的定义, 分解成矩阵 :
d e d X = ( a 11 ′ a 12 ′ a 13 ′ ⋯ a 1 n ′ a 21 ′ a 22 ′ a 23 ′ ⋯ a 2 n ′ a 31 ′ a 32 ′ a 33 ′ ⋯ a 3 n ′ ⋮ ⋮ ⋮ ⋱ ⋮ a m 1 ′ a m 2 ′ a m 3 ′ ⋯ a m n ′ ) ( w 11 w 12 w 13 ⋯ w 1 k w 21 w 22 w 23 ⋯ w 2 k w 31 w 32 w 33 ⋯ w 3 k ⋮ ⋮ ⋮ ⋱ ⋮ w n 1 w n 2 w n 3 ⋯ w n k ) \frac {d e}{d X}=\begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix} \begin{pmatrix} w_{11}&w_{12} &w_{13}&\cdots&w_{1k}\\ w_{21}&w_{22}&w_{23}&\cdots&w_{2k}\\ w_{31}&w_{32}&w_{33}&\cdots&w_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ w_{n1}&w_{n2}&w_{n3}&\cdots&w_{nk} \end{pmatrix} dXde=a11a21a31am1a12a22a32am2a13a23a33am3a1na2na3namnw11w21w31wn1w12w22w32wn2w13w23w33wn3w1kw2kw3kwnk

所以, 损失值 e 对 X 的梯度矩阵为 :
d e d X = ∇ e ( A ) W \frac {d e}{d X} =\nabla e_{(A)}W dXde=e(A)W

矩阵 ∇ e ( A ) \nabla e_{(A)} e(A) 已在前面求得.

3.4 关于 W 的反向传播

参考上例求解 :
a i j = ∑ t = 1 k x i t ⋅ w j t + b j    a i j = x i 1 w j 1 + x i 2 w j 2 + ⋯ + x i q w j q + ⋯ + x i k w j k + b j    w i j ∣ p q ′ = ∂ a i j ∂ w p q = { x i q p = j 0 p ≠ j    ∂ e ∂ w p q = ∑ i = 1 i = m ∑ j = 1 j = n ∂ e ∂ a i j ∂ a i j ∂ w p q = ∑ i = 1 i = m ∑ j = 1 j = n a i j ′ w i j ∣ p q ′    ∂ e ∂ w p q = ∑ i = 1 i = m a i p ′ x i q    d e d W = ( ∑ i = 1 i = m a i 1 ′ x i 1 ∑ i = 1 i = m a i 1 ′ x i 2 ∑ i = 1 i = m a i 1 ′ x i 3 ⋯ ∑ i = 1 i = m a i 1 ′ x i k    ∑ i = 1 i = m a i 2 ′ x i 1 ∑ i = 1 i = m a i 2 ′ x i 2 ∑ i = 1 i = m a i 2 ′ x i 3 ⋯ ∑ i = 1 i = m a i 2 ′ x i k    ∑ i = 1 i = m a i 3 ′ x i 1 ∑ i = 3 i = m a i 3 ′ x i 2 ∑ i = 1 i = m a i 3 ′ x i 3 ⋯ ∑ i = 1 i = m a i 3 ′ x i k ⋮ ⋮ ⋮ ⋱ ⋮ ∑ i = 1 i = m a i n ′ x i 1 ∑ i = 3 i = m a i n ′ x i n ∑ i = 1 i = m a i n ′ x i 3 ⋯ ∑ i = 1 i = m a i n ′ x i k ) a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{jt} +b_j\\ \;\\ a_{ij}= x_{i1}w_{j1} +x_{i2}w_{j2} +\cdots+x_{iq}w_{jq} +\cdots+x_{ik}w_{jk} +b_j\\ \;\\ w_{ij|pq}'=\frac{\partial a_{ij}}{\partial w_{pq}} = \left\{ \begin{array}{rr} x_{iq} & p = j \\ 0 & p \neq j \end{array} \right.\\\;\\ \frac {\partial e}{\partial w_{pq}} = \sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial w_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' w_{ij|pq}'\\ \;\\ \frac {\partial e}{\partial w_{pq}}=\sum_{i =1}^{i =m} a_{ip}'x_{iq}\\ \;\\ \frac {d e}{d W}= \begin{pmatrix} \sum_{i =1}^{i =m} a_{i1}'x_{i1}&\sum_{i =1}^{i =m} a_{i1}'x_{i2}&\sum_{i =1}^{i =m} a_{i1}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{i1}'x_{ik}\\ \;\\ \sum_{i =1}^{i =m} a_{i2}'x_{i1}&\sum_{i =1}^{i =m} a_{i2}'x_{i2}&\sum_{i =1}^{i =m} a_{i2}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{i2}'x_{ik}\\ \;\\ \sum_{i =1}^{i =m} a_{i3}'x_{i1}&\sum_{i =3}^{i =m} a_{i3}'x_{i2}&\sum_{i =1}^{i =m} a_{i3}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{i3}'x_{ik}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{i =1}^{i =m} a_{in}'x_{i1}&\sum_{i =3}^{i =m} a_{in}'x_{in}&\sum_{i =1}^{i =m} a_{in}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{ik}\\ \end{pmatrix}\\ aij=t=1kxitwjt+bjaij=xi1wj1+xi2wj2++xiqwjq++xikwjk+bjwijpq=wpqaij={xiq0p=jp̸=jwpqe=i=1i=mj=1j=naijewpqaij=i=1i=mj=1j=naijwijpqwpqe=i=1i=maipxiqdWde=i=1i=mai1xi1i=1i=mai2xi1i=1i=mai3xi1i=1i=mainxi1i=1i=mai1xi2i=1i=mai2xi2i=3i=mai3xi2i=3i=mainxini=1i=mai1xi3i=1i=mai2xi3i=1i=mai3xi3i=1i=mainxi3i=1i=mai1xiki=1i=mai2xiki=1i=mai3xiki=1i=mainxik

这个结果恰好满足矩阵乘法的定义, 分解成矩阵 :
d e d W = ( a 11 ′ a 21 ′ a 31 ′ ⋯ a m 1 ′ a 12 ′ a 22 ′ a 32 ′ ⋯ a m 2 ′ a 13 ′ a 23 ′ a 33 ′ ⋯ a m 3 ′ ⋮ ⋮ ⋮ ⋱ ⋮ a 1 n ′ a 2 n ′ a 3 n ′ ⋯ a m n ′ ) ( x 11 x 12 x 13 ⋯ x 1 k x 21 x 22 x 23 ⋯ x 2 k x 31 x 32 x 33 ⋯ x 3 k ⋮ ⋮ ⋮ ⋱ ⋮ x m 1 x m 2 x m 3 ⋯ x m k ) \frac {d e}{d W}= \begin{pmatrix} a_{11}'& a_{21}'& a_{31}'&\cdots& a_{m1}'\\ a_{12}'& a_{22}'& a_{32}'&\cdots& a_{m2}'\\ a_{13}'& a_{23}'& a_{33}'&\cdots& a_{m3}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{1n}'& a_{2n}'& a_{3n}'&\cdots& a_{mn}'\\ \end{pmatrix} \begin{pmatrix} x_{11}&x_{12} &x_{13}&\cdots&x_{1k}\\ x_{21}&x_{22}&x_{23}&\cdots&x_{2k}\\ x_{31}&x_{32}&x_{33}&\cdots&x_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{m1}&x_{m2}&x_{m3}&\cdots&x_{mk} \end{pmatrix} dWde=a11a12a13a1na21a22a23a2na31a32a33a3nam1am2am3amnx11x21x31xm1x12x22x32xm2x13x23x33xm3x1kx2kx3kxmk

所以, 损失值 e 对 W 的梯度矩阵为 :
d e d W = ∇ e ( A ) T X \frac {d e}{d W} =\nabla e_{(A)}^TX dWde=e(A)TX

矩阵 ∇ e ( A ) \nabla e_{(A)} e(A) 已在前面求得.

3.5 关于 e 对 b 的梯度

参考上例求解 :
a i j = ∑ t = 1 k x i t ⋅ w j t + b j    b i j ∣ p ′ = ∂ a i j ∂ b q = { 1 , q = j 0 , q ≠ j    ∂ e ∂ b q = ∑ i = 1 i = m ∑ j = 1 j = n ∂ e ∂ a i j ∂ a i j ∂ b q = ∑ i = 1 i = m ∑ j = 1 j = n a i j ′ b i j ∣ q ′    ∂ e ∂ b q = ∑ i = 1 i = m a i q ′ ⋅ 1    d e d b = ( ∑ i = 1 i = m a i 1 ′ , ∑ i = 1 i = m a i 2 ′ , ∑ i = 1 i = m a i 3 ′ , ⋯   , ∑ i = 1 i = m a i m ′ ) a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{jt} +b_j\\ \;\\ b_{ij|p}'=\frac{\partial a_{ij}}{\partial b_{q}} = \left\{ \begin{array}{rr} 1,& q = j\\ 0, & q \neq j \end{array} \right.\\ \;\\ \frac {\partial e}{\partial b_{q}} = \sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial b_{q}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' b_{ij|q}'\\ \;\\ \frac {\partial e}{\partial b_{q}} = \sum_{i = 1}^{i=m} a_{iq}'\cdot 1 \\ \;\\ \frac {d e}{d b} = (\sum_{i = 1}^{i=m} a_{i1}',\sum_{i = 1}^{i=m} a_{i2}',\sum_{i = 1}^{i=m} a_{i3}', \cdots ,\sum_{i = 1}^{i=m} a_{im}')\\ aij=t=1kxitwjt+bjbijp=bqaij={1,0,q=jq̸=jbqe=i=1i=mj=1j=naijebqaij=i=1i=mj=1j=naijbijqbqe=i=1i=maiq1dbde=(i=1i=mai1,i=1i=mai2,i=1i=mai3,,i=1i=maim)

所以, 损失值 e 对 b 的梯度矩阵为 :
d e d b = s u m ( ∇ e ( A ) ,    a x i s = 0 ) \frac {de}{db}=sum(\nabla e_{(A)},\; axis=0) dbde=sum(e(A),axis=0)

矩阵 ∇ e ( A ) \nabla e_{(A)} e(A) 已在前面求得. 式中的 a x i s = 0 axis=0 axis=0 表示对矩阵的第一维求和, 参考的是 numpy 的记法.

4. Affine 的另一种定义

上文中, W 矩阵经过转置 W T W^T WT 后再参与 Affine 运算.

在目前流行的教材中, 将 W 直接进行 Affine 运算的定义也很多.
A m × n = a f f i n e ( X , W , b ) = X m × k W k × n + b 1 × n    a i j = ∑ t = 1 k x i t ⋅ w t j + b j A_{m\times n} = affine(X,W,b) = X_{m\times k}W_{k\times n} + b_{1 \times n} \;\\ a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{tj} +b_j Am×n=affine(X,W,b)=Xm×kWk×n+b1×naij=t=1kxitwtj+bj

4.1 关于 X 的反向传播

a i j = x i 1 w 1 j + x i 2 w 2 j + ⋯ + x i q w q j + ⋯ + x i k w k j + b j    x i j ∣ p q ′ = ∂ a i j ∂ x p q = { w q j p = i 0 , p ≠ i a_{ij}= x_{i1}w_{1j} +x_{i2}w_{2j} +\cdots+x_{iq}w_{qj} +\cdots+x_{ik}w_{kj} +b_j\\ \;\\ x_{ij|pq}'=\frac{\partial a_{ij}}{\partial x_{pq}} = \left\{ \begin{array}{rr} w_{qj}& p = i\\ 0, & p \neq i \end{array} \right.\\ aij=xi1w1j+xi2w2j++xiqwqj++xikwkj+bjxijpq=xpqaij={wqj0,p=ip̸=i

∂ e ∂ x p q = ∑ i = 1 i = m ∑ j = 1 j = n ∂ e ∂ a i j ∂ a i j ∂ x p q = ∑ i = 1 i = m ∑ j = 1 j = n a i j ′ x i j ∣ p q ′ \frac {\partial e}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' x_{ij|pq}'\\ xpqe=i=1i=mj=1j=naijexpqaij=i=1i=mj=1j=naijxijpq

∂ e ∂ x p q = ∑ j = 1 j = n a p j ′ w q j    d e d X = ( ∑ j = 1 j = n a 1 j ′ w 1 j ∑ j = 1 j = n a 1 j ′ w 2 j ∑ j = 1 j = n a 1 j ′ w 3 j ⋯ ∑ j = 1 j = n a 1 j ′ w k j    ∑ j = 1 j = n a 2 j ′ w 1 j ∑ j = 1 j = n a 2 j ′ w 2 j ∑ j = 1 j = n a 2 j ′ w 3 j ⋯ ∑ j = 1 j = n a 2 j ′ w k j    ∑ j = 1 j = n a 3 j ′ w 1 j ∑ j = 1 j = n a 3 j ′ w 2 j ∑ j = 1 j = n a 3 j ′ w 3 j ⋯ ∑ j = 1 j = n a 3 j ′ w k j ⋮ ⋮ ⋮ ⋱ ⋮ ∑ j = 1 j = n a m j ′ w 1 j ∑ j = 1 j = n a m j ′ w 2 j ∑ j = 1 j = n a m j ′ w 3 j ⋯ ∑ j = 1 j = n a m j ′ w k j ) \frac {\partial e}{\partial x_{pq}}=\sum_{j =1}^{j =n} a_{pj}'w_{qj}\\ \;\\ \frac {d e}{d X}=\begin{pmatrix} \sum_{j =1}^{j =n} a_{1j}'w_{1j}&\sum_{j =1}^{j =n} a_{1j}'w_{2j}&\sum_{j =1}^{j =n} a_{1j}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{1j}'w_{kj}\\\;\\ \sum_{j =1}^{j =n} a_{2j}'w_{1j}&\sum_{j =1}^{j =n} a_{2j}'w_{2j}&\sum_{j =1}^{j =n} a_{2j}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{2j}'w_{kj}\\\;\\ \sum_{j =1}^{j =n} a_{3j}'w_{1j}&\sum_{j =1}^{j =n} a_{3j}'w_{2j}&\sum_{j =1}^{j =n} a_{3j}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{3j}'w_{kj}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{j =1}^{j =n} a_{mj}'w_{1j}&\sum_{j =1}^{j =n} a_{mj}'w_{2j}&\sum_{j =1}^{j =n} a_{mj}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{mj}'w_{kj}\\ \end{pmatrix} xpqe=j=1j=napjwqjdXde=j=1j=na1jw1jj=1j=na2jw1jj=1j=na3jw1jj=1j=namjw1jj=1j=na1jw2jj=1j=na2jw2jj=1j=na3jw2jj=1j=namjw2jj=1j=na1jw3jj=1j=na2jw3jj=1j=na3jw3jj=1j=namjw3jj=1j=na1jwkjj=1j=na2jwkjj=1j=na3jwkjj=1j=namjwkj

d e d X = ( a 11 ′ a 12 ′ a 13 ′ ⋯ a 1 n ′ a 21 ′ a 22 ′ a 23 ′ ⋯ a 2 n ′ a 31 ′ a 32 ′ a 33 ′ ⋯ a 3 n ′ ⋮ ⋮ ⋮ ⋱ ⋮ a m 1 ′ a m 2 ′ a m 3 ′ ⋯ a m n ′ ) ( w 11 w 21 w 31 ⋯ w k 1 w 12 w 22 w 32 ⋯ w k 2 w 13 w 23 w 33 ⋯ w k 3 ⋮ ⋮ ⋮ ⋱ ⋮ w 1 n w 2 n w 3 n ⋯ w k n ) \frac {d e}{d X}=\begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix} \begin{pmatrix} w_{11}&w_{21} &w_{31}&\cdots&w_{k1}\\ w_{12}&w_{22}&w_{32}&\cdots&w_{k2}\\ w_{13}&w_{23}&w_{33}&\cdots&w_{k3}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ w_{1n}&w_{2n}&w_{3n}&\cdots&w_{kn} \end{pmatrix} dXde=a11a21a31am1a12a22a32am2a13a23a33am3a1na2na3namnw11w12w13w1nw21w22w23w2nw31w32w33w3nwk1wk2wk3wkn

d e d X = ∇ e ( A ) W T \frac {d e}{d X} =\nabla e_{(A)}W^T dXde=e(A)WT

4.2 关于 W 的反向传播

a i j = x i 1 w 1 j + x i 2 w 2 j + ⋯ + x i p w p j + ⋯ + x i k w k j + b j    w i j ∣ p q ′ = ∂ a i j ∂ w p q = { x i p q = j 0 q ≠ j    ∂ e ∂ w p q = ∑ i = 1 i = m ∑ j = 1 j = n ∂ e ∂ a i j ∂ a i j ∂ w p q = ∑ i = 1 i = m ∑ j = 1 j = n a i j ′ w i j ∣ p q ′ a_{ij}= x_{i1}w_{1j} +x_{i2}w_{2j} +\cdots+x_{ip}w_{pj} +\cdots+x_{ik}w_{kj} +b_j\\ \;\\ w_{ij|pq}'=\frac{\partial a_{ij}}{\partial w_{pq}} = \left\{ \begin{array}{rr} x_{ip} & q = j \\ 0 & q \neq j \end{array} \right.\\\;\\ \frac {\partial e}{\partial w_{pq}} = \sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial w_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' w_{ij|pq}'\\ aij=xi1w1j+xi2w2j++xipwpj++xikwkj+bjwijpq=wpqaij={xip0q=jq̸=jwpqe=i=1i=mj=1j=naijewpqaij=i=1i=mj=1j=naijwijpq
∂ e ∂ w p q = ∑ i = 1 i = m a i q ′ x i p    d e d W = ( ∑ i = 1 i = m a i 1 ′ x i 1 ∑ i = 1 i = m a i 2 ′ x i 1 ∑ i = 1 i = m a i 3 ′ x i 1 ⋯ ∑ i = 1 i = m a i n ′ x i 1    ∑ i = 1 i = m a i 1 ′ x i 2 ∑ i = 1 i = m a i 2 ′ x i 2 ∑ i = 1 i = m a i 3 ′ x i 2 ⋯ ∑ i = 1 i = m a i n ′ x i 2    ∑ i = 1 i = m a i 1 ′ x i 3 ∑ i = 3 i = m a i 2 ′ x i 3 ∑ i = 1 i = m a i 3 ′ x i 3 ⋯ ∑ i = 1 i = m a i n ′ x i 3 ⋮ ⋮ ⋮ ⋱ ⋮ ∑ i = 1 i = m a i 1 ′ x i k ∑ i = 3 i = m a i 2 ′ x i k ∑ i = 1 i = m a i 3 ′ x i k ⋯ ∑ i = 1 i = m a i n ′ x i k ) \frac {\partial e}{\partial w_{pq}}=\sum_{i =1}^{i =m} a_{iq}'x_{ip}\\ \;\\ \frac {d e}{d W}= \begin{pmatrix} \sum_{i =1}^{i =m} a_{i1}'x_{i1}&\sum_{i =1}^{i =m} a_{i2}'x_{i1}&\sum_{i =1}^{i =m} a_{i3}'x_{i1}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{i1}\\ \;\\ \sum_{i =1}^{i =m} a_{i1}'x_{i2}&\sum_{i =1}^{i =m} a_{i2}'x_{i2}&\sum_{i =1}^{i =m} a_{i3}'x_{i2}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{i2}\\ \;\\ \sum_{i =1}^{i =m} a_{i1}'x_{i3}&\sum_{i =3}^{i =m} a_{i2}'x_{i3}&\sum_{i =1}^{i =m} a_{i3}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{i3}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{i =1}^{i =m} a_{i1}'x_{ik}&\sum_{i =3}^{i =m} a_{i2}'x_{ik}&\sum_{i =1}^{i =m} a_{i3}'x_{ik}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{ik}\\ \end{pmatrix}\\ wpqe=i=1i=maiqxipdWde=i=1i=mai1xi1i=1i=mai1xi2i=1i=mai1xi3i=1i=mai1xiki=1i=mai2xi1i=1i=mai2xi2i=3i=mai2xi3i=3i=mai2xiki=1i=mai3xi1i=1i=mai3xi2i=1i=mai3xi3i=1i=mai3xiki=1i=ma

你可能感兴趣的:(深度学习基础)