HBU-NNDL Assignment 9: Implementing BPTT with NumPy and PyTorch

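Contents

6-1P: Derive the RNN backpropagation algorithm BPTT.

RNN forward propagation

Backpropagation Through Time (BPTT)

Design a simple RNN model, implement the backpropagation operator with NumPy and PyTorch, and test it with numerical values.

Reflections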


 

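6-1P: Derive the RNN backpropagation algorithm BPTT.

Reference: 学习笔记-循环神经网络(RNN)及沿时反向传播BPTT (Study Notes: Recurrent Neural Networks (RNN) and Backpropagation Through Time, BPTT) - Zhihu

RNN forward propagation

Suppose we have a time series t=1,2,...,L. At each time step t we have: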

z^{(t)}=Ux^{(t)}+Wh^{(t-1)}+b \\ h^{(t)}=f(z^{(t)}) \\ s^{(t)}=Vh^{(t)}+c \\ y^{(t)}=g(s^{(t)})

[Figure 1: RNN structure]

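This is the structure of the RNN. As shown, the hidden state h^{(t)} at each time step t is computed from both the current input x^{(t)} and the previous hidden state h^{(t-1)}. The symbols are defined in detail below: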

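Symbol | Meaning | Dimensions
x^{(t)} | input at time step t | (K\times 1)
z^{(t)} | weighted input to the hidden layer at time step t | (N\times 1)
h^{(t)} | hidden state at time step t | (N\times 1)
s^{(t)} | weighted input to the output layer at time step t | (M\times 1)
y^{(t)} | output at time step t | (M\times 1)
E^{(t)} | loss at time step t | scalar
U | hidden-layer weights on the input, shared across all time steps | (N\times K)
W | hidden-layer weights on the previous state, shared across all time steps | (N\times N)
V | output-layer weights, shared across all time steps | (M\times N)
b | hidden-layer bias, shared across all time steps | (N\times 1)
c | output-layer bias, shared across all time steps | (M\times 1)
g() | output-layer activation function |
f() | hidden-layer activation function |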
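The forward equations and the dimensions in the table can be illustrated with a minimal NumPy sketch. The function name rnn_forward, the sizes K, N, M, L, the random initialization, and the choices f = tanh and g = identity are all assumptions made for illustration; the derivation itself leaves f and g unspecified.

```python
import numpy as np

def rnn_forward(xs, h0, U, W, V, b, c, f=np.tanh, g=lambda s: s):
    """Apply the four forward equations to a list of column-vector inputs.

    f and g default to tanh and identity, which is an assumption of this sketch.
    """
    h = h0
    zs, hs, ss, ys = [], [], [], []
    for x in xs:
        z = U @ x + W @ h + b   # weighted input to the hidden layer, shape (N, 1)
        h = f(z)                # hidden state, shape (N, 1)
        s = V @ h + c           # weighted input to the output layer, shape (M, 1)
        y = g(s)                # output, shape (M, 1)
        zs.append(z)
        hs.append(h)
        ss.append(s)
        ys.append(y)
    return zs, hs, ss, ys

# Hypothetical sizes: K inputs, N hidden units, M outputs, L time steps.
K, N, M, L = 4, 5, 3, 6
rng = np.random.default_rng(0)
U = rng.standard_normal((N, K)) * 0.1
W = rng.standard_normal((N, N)) * 0.1
V = rng.standard_normal((M, N)) * 0.1
b = np.zeros((N, 1))
c = np.zeros((M, 1))

xs = [rng.standard_normal((K, 1)) for _ in range(L)]
zs, hs, ss, ys = rnn_forward(xs, np.zeros((N, 1)), U, W, V, b, c)
print(hs[-1].shape, ys[-1].shape)
```

The same U, W, V, b, c are reused at every step of the loop, which is what "shared across all time steps" means in the table.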

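Backpropagation Through Time (BPTT)

After E^{(t)} is computed, the model parameters cannot be updated immediately. Inputs must keep being fed in along time t until the loss at every time step has been computed. The total loss of the model is E=\sum_t E^{(t)}.

1. Compute \frac{\partial E}{\partial V}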

\frac{\partial E}{\partial V}=\sum_{t}^{}\frac{\partial E^{(t)}}{\partial V}

From the formulas s^{(t)}=Vh^{(t)}+c and y^{(t)}=g(s^{(t)}), it follows easily that:

\frac{\partial E^{(t)}}{\partial V_{ij}}=\frac{\partial E^{(t)}}{\partial s_i^{(t)}}\frac{\partial s_i^{(t)}}{\partial V_{ij}}=\frac{\partial E^{(t)}}{\partial y_i^{(t)}}\frac{\partial y_i^{(t)}}{\partial s_i^{(t)}}\frac{\partial s_i^{(t)}}{\partial V_{ij}}=\frac{\partial E^{(t)}}{\partial y_i^{(t)}}g'(s_i^{(t)})h_j^{(t)}

Generalizing to matrix form:

\frac{\partial E}{\partial V}=\sum_{t}^{}[\frac{\partial E^{(t)}}{\partial y^{(t)}}\odot g'(s^{(t)})](h^{(t)})^T
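The matrix form can be checked numerically. The sketch below assumes g = tanh and a squared-error loss E^{(t)} = 0.5 * ||y^{(t)} - target^{(t)}||^2, neither of which is fixed by the derivation; the hidden states are random stand-ins because they do not depend on V. The helper names grad_V and numeric_grad are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M, L = 5, 3, 4
V = rng.standard_normal((M, N)) * 0.5
c = rng.standard_normal((M, 1)) * 0.1
# Stand-in hidden states and targets (assumed data, for illustration only).
hs = [np.tanh(rng.standard_normal((N, 1))) for _ in range(L)]
targets = [rng.standard_normal((M, 1)) for _ in range(L)]

def total_loss(V):
    # E = sum_t 0.5 * ||g(V h + c) - target||^2 with g = tanh
    return sum(0.5 * np.sum((np.tanh(V @ h + c) - tgt) ** 2)
               for h, tgt in zip(hs, targets))

def grad_V(V):
    # dE/dV = sum_t [dE/dy * g'(s)] h^T, with * elementwise
    dV = np.zeros_like(V)
    for h, tgt in zip(hs, targets):
        y = np.tanh(V @ h + c)
        dE_dy = y - tgt                  # derivative of the squared error
        dV += (dE_dy * (1.0 - y ** 2)) @ h.T   # tanh'(s) = 1 - tanh(s)^2
    return dV

def numeric_grad(fn, P, eps=1e-6):
    # Central finite differences, one entry at a time.
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        Pp = P.copy()
        Pm = P.copy()
        Pp[idx] += eps
        Pm[idx] -= eps
        G[idx] = (fn(Pp) - fn(Pm)) / (2 * eps)
    return G

dV_analytic = grad_V(V)
dV_numeric = numeric_grad(total_loss, V)
max_abs_diff = np.max(np.abs(dV_analytic - dV_numeric))
print(max_abs_diff)
```

A small max_abs_diff indicates that the analytic expression and the finite-difference estimate agree for this particular choice of g and loss.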

2. Compute \frac{\partial E}{\partial U}

\frac{\partial E}{\partial U}=\sum_{t}^{}(\frac{\partial E}{\partial U})^{(t)}

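A careful reader will notice that, unlike the earlier (\frac{\partial E}{\partial V}=\sum_t\frac{\partial E^{(t)}}{\partial V}), the time superscript ^{(t)} now sits outside the parentheses. The reason, briefly: V is in the output layer, so its gradient at each time step depends only on the loss at that time step (E^{(t)}). U and W, however, are in the hidden layer and take part in the computation at the next time step, so their gradient at each time step must be expressed in terms of the total loss E.

From the formulas z^{(t)}=Ux^{(t)}+Wh^{(t-1)}+b and h^{(t)}=f(z^{(t)}), we have: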

(\frac{\partial E}{\partial U_{ij}})^{(t)}=\frac{\partial E}{\partial z_i^{(t)}}\frac{\partial z_i^{(t)}}{\partial U_{ij}}=\frac{\partial E}{\partial z_i^{(t)}}x_j^{(t)}

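When computing the term \frac{\partial E}{\partial z_i^{(t)}}, a property of the RNN matters: computing h^{(t)} requires both x^{(t)} and h^{(t-1)}. So z^{(t)} affects not only the output at the current time step but also the output at the next time step. The dependencies among the variables are shown in the figure below:

[Figure 2: dependencies among the variables]

So \frac{\partial E}{\partial z_i^{(t)}} should consist of two parts: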

[Figure 3: equation image]

First half:

[Figure 4: equation image]

Second half:

[Figure 5: equation image]

Substituting into the original expression gives:

[Figure 6: equation image]

Introduce error terms: let \delta _y^{(t)}=\frac{\partial E^{(t)}}{\partial s^{(t)}} and \delta _h^{(t)}=\frac{\partial E}{\partial z^{(t)}}.

The expression above can be rewritten as:

[Figure 7: equation image]

Generalizing to matrix form:

3. Compute \frac{\partial E}{\partial W}

Observing the formula z^{(t)}=Ux^{(t)}+Wh^{(t-1)}+b, we have:

[Figure 8: equation image]

This has essentially the same form as \frac{\partial E}{\partial U}, so the matrix form of \frac{\partial E}{\partial W} follows directly:
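The results of steps 2 and 3 can be written out as follows. This is a reconstruction from the definitions of \delta _y^{(t)} and \delta _h^{(t)} and the forward equations above, assuming that f and g act elementwise and that nothing contributes to the loss after the last step L (so \delta _h^{(L+1)}=0). It is a sketch to check against the referenced article, not a quotation from it.

```latex
% Two paths from z_i^{(t)} to E: through s^{(t)} (current loss) and through z^{(t+1)} (later losses)
\frac{\partial E}{\partial z_i^{(t)}}
  = \sum_{j=1}^{M} \frac{\partial E^{(t)}}{\partial s_j^{(t)}} V_{ji}\, f'(z_i^{(t)})
  + \sum_{j=1}^{N} \frac{\partial E}{\partial z_j^{(t+1)}} W_{ji}\, f'(z_i^{(t)})

% Matrix form of the recursion
\delta_y^{(t)} = \frac{\partial E^{(t)}}{\partial y^{(t)}} \odot g'(s^{(t)}), \qquad
\delta_h^{(t)} = \left[ V^{T}\delta_y^{(t)} + W^{T}\delta_h^{(t+1)} \right] \odot f'(z^{(t)}), \qquad
\delta_h^{(L+1)} = 0

% Parameter gradients
\frac{\partial E}{\partial U} = \sum_{t} \delta_h^{(t)} (x^{(t)})^{T}, \qquad
\frac{\partial E}{\partial W} = \sum_{t} \delta_h^{(t)} (h^{(t-1)})^{T}
```

The dimensions agree with the symbol table: V^{T}\delta_y^{(t)} and W^{T}\delta_h^{(t+1)} are both (N\times 1), \delta_h^{(t)}(x^{(t)})^{T} is (N\times K), and \delta_h^{(t)}(h^{(t-1)})^{T} is (N\times N).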

 

Comparing the expressions above gives the way to compute \delta _h^{(t)}.
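The backward recursion for \delta _h^{(t)} can be sketched in NumPy and checked against finite differences. This sketch assumes f = tanh, g = identity, a squared-error loss at every step, and \delta _h = 0 beyond the last step; the names forward, bptt and numeric_grad are hypothetical helpers, not part of any library.

```python
import numpy as np

rng = np.random.default_rng(2)
K, N, M, L = 3, 4, 2, 5
U = rng.standard_normal((N, K)) * 0.5
W = rng.standard_normal((N, N)) * 0.5
V = rng.standard_normal((M, N)) * 0.5
b = np.zeros((N, 1))
c = np.zeros((M, 1))
xs = [rng.standard_normal((K, 1)) for _ in range(L)]
targets = [rng.standard_normal((M, 1)) for _ in range(L)]
h0 = np.zeros((N, 1))

def forward(U, W):
    # hs[0] is the initial state; hs[t + 1] is the state after zero-based step t.
    h, hs, ys = h0, [h0], []
    for x in xs:
        h = np.tanh(U @ x + W @ h + b)
        hs.append(h)
        ys.append(V @ h + c)             # g = identity (assumption)
    return hs, ys

def total_loss(U, W):
    _, ys = forward(U, W)
    return sum(0.5 * np.sum((y - tgt) ** 2) for y, tgt in zip(ys, targets))

def bptt(U, W):
    hs, ys = forward(U, W)
    dU, dW = np.zeros_like(U), np.zeros_like(W)
    delta_h_next = np.zeros((N, 1))      # no error arrives from beyond the last step
    for t in reversed(range(L)):
        delta_y = ys[t] - targets[t]     # dE^{(t)}/ds^{(t)} since g' = 1
        # Error from the output at step t plus error carried back from step t + 1,
        # multiplied elementwise by tanh'(z) = 1 - h^2.
        delta_h = (V.T @ delta_y + W.T @ delta_h_next) * (1.0 - hs[t + 1] ** 2)
        dU += delta_h @ xs[t].T
        dW += delta_h @ hs[t].T          # hs[t] is the previous hidden state
        delta_h_next = delta_h
    return dU, dW

def numeric_grad(fn, P, eps=1e-6):
    # Central finite differences, one entry at a time.
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        Pp = P.copy()
        Pm = P.copy()
        Pp[idx] += eps
        Pm[idx] -= eps
        G[idx] = (fn(Pp) - fn(Pm)) / (2 * eps)
    return G

dU, dW = bptt(U, W)
err_U = np.max(np.abs(dU - numeric_grad(lambda P: total_loss(P, W), U)))
err_W = np.max(np.abs(dW - numeric_grad(lambda P: total_loss(U, P), W)))
print(err_U, err_W)
```

The term W.T @ delta_h_next is the error along time and V.T @ delta_y is the error along the network direction; dropping either one makes the finite-difference check fail for L > 1.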

 

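Design a simple RNN model, implement the backpropagation operator with NumPy and PyTorch, and test it with numerical values.

Reference link: L5W1作业1 手把手实现循环神经网络 (L5W1 Assignment 1: Building a Recurrent Neural Network Step by Step), 追寻远方的人's blog - CSDN Blog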

import torch
import numpy as np


class RNNCell:
    def __init__(self, weight_ih, weight_hh,
                 bias_ih, bias_hh):
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh

        self.x_stack = []
        self.dx_list = []
        self.dw_ih_stack = []
        self.dw_hh_stack = []
        self.db_ih_stack = []
        self.db_hh_stack = []

        self.prev_hidden_stack = []
        self.next_hidden_stack = []

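        # gradient w.r.t. the hidden state carried back from the later time step;
        # set to zeros by each forward call and updated by each backward call
        self.prev_dh = None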

    def __call__(self, x, prev_hidden):
        self.x_stack.append(x)

        next_h = np.tanh(
            np.dot(x, self.weight_ih.T)
            + np.dot(prev_hidden, self.weight_hh.T)
            + self.bias_ih + self.bias_hh)

        self.prev_hidden_stack.append(prev_hidden)
        self.next_hidden_stack.append(next_h)
        # reset the carried gradient so the first backward call starts from zero
        self.prev_dh = np.zeros(next_h.shape)
        return next_h

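    def backward(self, dh):
        # Must be called once per time step, in reverse time order: pop the
        # values cached by the latest forward call not yet back-propagated.
        x = self.x_stack.pop()
        prev_hidden = self.prev_hidden_stack.pop()
        next_hidden = self.next_hidden_stack.pop()

        # Total gradient w.r.t. this step's hidden state (external dh plus the
        # part carried back through time), multiplied by tanh'(z) = 1 - h^2.
        d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)
        # Gradient w.r.t. the previous hidden state, used by the next backward call.
        self.prev_dh = np.dot(d_tanh, self.weight_hh)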

        dx = np.dot(d_tanh, self.weight_ih)
        self.dx_list.insert(0, dx)

        dw_ih = np.dot(d_tanh.T, x)
        self.dw_ih_stack.append(dw_ih)

        dw_hh = np.dot(d_tanh.T, prev_hidden)
        self.dw_hh_stack.append(dw_hh)

        self.db_ih_stack.append(d_tanh)
        self.db_hh_stack.append(d_tanh)

        return self.dx_list


if __name__ == '__main__':
    np.random.seed(123)
    torch.random.manual_seed(123)
    np.set_printoptions(precision=6, suppress=True)

    rnn_PyTorch = torch.nn.RNN(4, 5).double()
    rnn_numpy = RNNCell(rnn_PyTorch.all_weights[0][0].data.numpy(),
                        rnn_PyTorch.all_weights[0][1].data.numpy(),
                        rnn_PyTorch.all_weights[0][2].data.numpy(),
                        rnn_PyTorch.all_weights[0][3].data.numpy())

    nums = 3
    x3_numpy = np.random.random((nums, 3, 4))
    x3_tensor = torch.tensor(x3_numpy, requires_grad=True)

    h3_numpy = np.random.random((1, 3, 5))
    h3_tensor = torch.tensor(h3_numpy, requires_grad=True)

    dh_numpy = np.random.random((nums, 3, 5))
    dh_tensor = torch.tensor(dh_numpy, requires_grad=True)

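    # torch.nn.RNN returns (output, h_n); h3_tensor[0] below is the output
    # at every time step, with shape (nums, 3, 5)
    h3_tensor = rnn_PyTorch(x3_tensor, h3_tensor)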
    h_numpy_list = []

    h_numpy = h3_numpy[0]
    for i in range(nums):
        h_numpy = rnn_numpy(x3_numpy[i], h_numpy)
        h_numpy_list.append(h_numpy)

    h3_tensor[0].backward(dh_tensor)
    for i in reversed(range(nums)):
        rnn_numpy.backward(dh_numpy[i])

    print("numpy_hidden :\n", np.array(h_numpy_list))
    print("torch_hidden :\n", h3_tensor[0].data.numpy())
    print("-----------------------------------------------")

    print("dx_numpy :\n", np.array(rnn_numpy.dx_list))
    print("dx_torch :\n", x3_tensor.grad.data.numpy())
    print("------------------------------------------------")

    print("dw_ih_numpy :\n",
          np.sum(rnn_numpy.dw_ih_stack, axis=0))
    print("dw_ih_torch :\n",
          rnn_PyTorch.all_weights[0][0].grad.data.numpy())
    print("------------------------------------------------")

    print("dw_hh_numpy :\n",
          np.sum(rnn_numpy.dw_hh_stack, axis=0))
    print("dw_hh_torch :\n",
          rnn_PyTorch.all_weights[0][1].grad.data.numpy())
    print("------------------------------------------------")

    print("db_ih_numpy :\n",
          np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))
    print("db_ih_torch :\n",
          rnn_PyTorch.all_weights[0][2].grad.data.numpy())
    print("-----------------------------------------------")
    print("db_hh_numpy :\n",
          np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))
    print("db_hh_torch :\n",
          rnn_PyTorch.all_weights[0][3].grad.data.numpy())

numpy_hidden :
 [[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391]
  [ 0.365172 -0.361254  0.426838 -0.448951  0.331553]
  [ 0.589187 -0.188248  0.684941 -0.45859   0.190099]]

 [[ 0.146213 -0.306517  0.297109  0.370957 -0.040084]
  [-0.009201 -0.365735  0.333659  0.486789  0.061897]
  [ 0.030064 -0.282985  0.42643   0.025871  0.026388]]

 [[ 0.225432 -0.015057  0.116555  0.080901  0.260097]
  [ 0.368327  0.258664  0.357446  0.177961  0.55928 ]
  [ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
torch_hidden :
 [[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391]
  [ 0.365172 -0.361254  0.426838 -0.448951  0.331553]
  [ 0.589187 -0.188248  0.684941 -0.45859   0.190099]]

 [[ 0.146213 -0.306517  0.297109  0.370957 -0.040084]
  [-0.009201 -0.365735  0.333659  0.486789  0.061897]
  [ 0.030064 -0.282985  0.42643   0.025871  0.026388]]

 [[ 0.225432 -0.015057  0.116555  0.080901  0.260097]
  [ 0.368327  0.258664  0.357446  0.177961  0.55928 ]
  [ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
-----------------------------------------------
dx_numpy :
 [[[-0.643965  0.215931 -0.476378  0.072387]
  [-1.221727  0.221325 -0.757251  0.092991]
  [-0.59872  -0.065826 -0.390795  0.037424]]

 [[-0.537631 -0.303022 -0.364839  0.214627]
  [-0.815198  0.392338 -0.564135  0.217464]
  [-0.931365 -0.254144 -0.561227  0.164795]]

 [[-1.055966  0.249554 -0.623127  0.009784]
  [-0.45858   0.108994 -0.240168  0.117779]
  [-0.957469  0.315386 -0.616814  0.205634]]]
dx_torch :
 [[[-0.643965  0.215931 -0.476378  0.072387]
  [-1.221727  0.221325 -0.757251  0.092991]
  [-0.59872  -0.065826 -0.390795  0.037424]]

 [[-0.537631 -0.303022 -0.364839  0.214627]
  [-0.815198  0.392338 -0.564135  0.217464]
  [-0.931365 -0.254144 -0.561227  0.164795]]

 [[-1.055966  0.249554 -0.623127  0.009784]
  [-0.45858   0.108994 -0.240168  0.117779]
  [-0.957469  0.315386 -0.616814  0.205634]]]
------------------------------------------------
dw_ih_numpy :
 [[3.918335 2.958509 3.725173 4.157478]
 [1.261197 0.812825 1.10621  0.97753 ]
 [2.216469 1.718251 2.366936 2.324907]
 [3.85458  3.052212 3.643157 3.845696]
 [1.806807 1.50062  1.615917 1.521762]]
dw_ih_torch :
 [[3.918335 2.958509 3.725173 4.157478]
 [1.261197 0.812825 1.10621  0.97753 ]
 [2.216469 1.718251 2.366936 2.324907]
 [3.85458  3.052212 3.643157 3.845696]
 [1.806807 1.50062  1.615917 1.521762]]
------------------------------------------------
dw_hh_numpy :
 [[ 2.450078  0.243735  4.269672  0.577224  1.46911 ]
 [ 0.421015  0.372353  0.994656  0.962406  0.518992]
 [ 1.079054  0.042843  2.12169   0.863083  0.757618]
 [ 2.225794  0.188735  3.682347  0.934932  0.955984]
 [ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
dw_hh_torch :
 [[ 2.450078  0.243735  4.269672  0.577224  1.46911 ]
 [ 0.421015  0.372353  0.994656  0.962406  0.518992]
 [ 1.079054  0.042843  2.12169   0.863083  0.757618]
 [ 2.225794  0.188735  3.682347  0.934932  0.955984]
 [ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
------------------------------------------------
db_ih_numpy :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_ih_torch :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
-----------------------------------------------
db_hh_numpy :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_hh_torch :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]

 

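Reflections

In this assignment I derived BPTT by hand. I had trouble concentrating during the online classes, so the assignment took real effort.

In the article I referenced, 学习笔记-循环神经网络(RNN)及沿时反向传播BPTT - Zhihu, the author splits the BPTT error into an error along the network direction and an error along time. I think this split makes BPTT easier to understand.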

[Figure 9]

[Figure 10]

[Figure 11]

 
