聊聊关于矩阵反向传播的梯度计算

目录

1. 前向传播

2. 反向传播

3. 矩阵反向传播

4. 总结


1. 前向传播

建立如图所示的简单网络

W 是权重矩阵,初始赋值为 2*2 的矩阵

X 是输入特征,初始赋值为 2*1 的矩阵

这样通过矩阵乘法 , Y = WX ,应该得到一个 2*1 的输出矩阵

最后定义loss 为二范数的平方,即 out = 0.22^2 + 0.26^2 = 0.116

聊聊关于矩阵反向传播的梯度计算_第1张图片

 

代码演示为:

torch.norm 是计算矩阵范数的函数

聊聊关于矩阵反向传播的梯度计算_第2张图片

 

2. 反向传播

反向传播的计算根据链式法则,这里不作数学上的推导。在计算图当中,只需要记住以下常用的即可:(注:需要注意的是,传递的值是反向传递过来的,还是正向传播输入的)

  • 加法节点:上游传回来的值直接传递到下游
  • 乘法节点:上游传回来的值,乘上输入信号的翻转值
  • Max 门:上游传来的值,只传递给输入信号的最大者,其余为0
  • ReLU : 如果输入信号大于 0,则上游直接传递;否则,为0

本章,只需要知道乘法节点计算图传递的规则即可

3. 矩阵反向传播

先将结果进行展示:

聊聊关于矩阵反向传播的梯度计算_第3张图片 

 


聊聊关于矩阵反向传播的梯度计算_第4张图片

 

首先,Y 的梯度很容易计算,Y = [0.22 0.26](转置)

因为这里out 是二范数的平方,因此out = x1^2 + x2^2 ,对Y进行偏导的话,就是2倍的关系

 


聊聊关于矩阵反向传播的梯度计算_第5张图片

 

对W和X进行计算的话,因为这里是乘法节点(W*X),因此这里需要将输入信号反转

例如求取W反向传播,应该是上游传递过来的和X的矩阵乘法 

这里只要记住反向传播的维度要和输入保持一致就行了

也就是说,目标是得到一个2*2大小的W反向传播,已经知道上游传过来的是一个2*1大小的矩阵,而将输入信号翻转的X是一个2*1大小的。那么根据矩阵乘法,只能是上游传递过来的 * X的转置

聊聊关于矩阵反向传播的梯度计算_第6张图片 

 


聊聊关于矩阵反向传播的梯度计算_第7张图片

同样的道理,对X计算反向传播

目标是得到一个2*1大小的X反向传播,已经知道上游传过来的是一个2*1大小的矩阵,而将输入翻转的W是一个2*2大小的。那么根据矩阵乘法,只能是W的转置 * 上游传递过来的值

聊聊关于矩阵反向传播的梯度计算_第8张图片

 

4. 总结

本章采用的是 W * X = Y 的方式计算。因为资料或者书籍上面有时候矩阵乘法的顺序会不一样,有的还会加上转置等等。其实这些都是为了满足矩阵乘法规则

为了不会混乱,可以这样记忆。可以不用考虑乘法的顺序或者有无转置

A * B = C 的矩阵乘法

计算谁的时候,就用反向传递的值替换掉谁,然后将另一个元素转置。顺序不变

例如:

W_{2*2} * X_{2*1} = Y_{2*1}

计算W梯度的时候,用反向传递的值替换掉W,变成 y * X(这里y是反向传递的值,本章y = 2 Y)

然后另一个元素转置,变成y * X(转置)

计算X梯度的时候,用反向传递的值替换掉X,变成 W * y(这里y是反向传递的值,本章y = 2 Y)

然后另一个元素转置,变成W(转置)* y

或者根据上游传递的信号的维度,和 输入信号翻转的维度进行矩阵计算,也可以得到正确的计算

你可能感兴趣的:(关于PyTorch,的,smart,power,神经网络,线性代数)