HBU_Neural Networks and Deep Learning Assignment 9: Implementing Backpropagation Through Time

Contents

  • Preliminary notes
  • Exercise 1
  • Exercise 2

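Preliminary notes

The exercises come from NNDL Assignment 9: implementing BPTT with NumPy and PyTorch.
My understanding is limited, so mistakes are likely; corrections are welcome.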

Exercise 1

Derive the backpropagation through time (BPTT) algorithm for recurrent neural networks.



Known relations (forward pass):

$$
\begin{aligned}
z_1 &= Uh_0 + Wx_1 + b \\
z_2 &= Uh_1 + Wx_2 + b \\
z_3 &= Uh_2 + Wx_3 + b \\
h_t &= f(z_t) \\
\widehat{y_t} &= g(h_t)
\end{aligned}
$$

For the loss at step $t = 1$, the gradient with respect to $U$ is

$$
\frac{\partial\mathcal{L}_1}{\partial U}= \frac{\partial\mathcal{L}_1}{\partial\widehat{y_1}}\cdot \frac{\partial\widehat{y_1}}{\partial h_1}\cdot \frac{\partial h_1}{\partial z_1}\cdot \frac{\partial z_1}{\partial U}
$$

For $t = 2$, $z_2$ depends on $U$ both directly and through $h_1$, so there are two paths:

$$
\frac{\partial\mathcal{L}_2}{\partial U}= \frac{\partial\mathcal{L}_2}{\partial\widehat{y_2}}\cdot \frac{\partial\widehat{y_2}}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial U} + \frac{\partial\mathcal{L}_2}{\partial\widehat{y_2}}\cdot \frac{\partial\widehat{y_2}}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial h_1}\cdot \frac{\partial h_1}{\partial z_1}\cdot \frac{\partial z_1}{\partial U}
$$

For $t = 3$ there are three paths:

$$
\begin{aligned}
\frac{\partial\mathcal{L}_3}{\partial U}= {} & \frac{\partial\mathcal{L}_3}{\partial\widehat{y_3}}\cdot \frac{\partial\widehat{y_3}}{\partial h_3}\cdot \frac{\partial h_3}{\partial z_3}\cdot \frac{\partial z_3}{\partial U} \\
& + \frac{\partial\mathcal{L}_3}{\partial\widehat{y_3}}\cdot \frac{\partial\widehat{y_3}}{\partial h_3}\cdot \frac{\partial h_3}{\partial z_3}\cdot \frac{\partial z_3}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial U} \\
& + \frac{\partial\mathcal{L}_3}{\partial\widehat{y_3}}\cdot \frac{\partial\widehat{y_3}}{\partial h_3}\cdot \frac{\partial h_3}{\partial z_3}\cdot \frac{\partial z_3}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial h_1}\cdot \frac{\partial h_1}{\partial z_1}\cdot \frac{\partial z_1}{\partial U}
\end{aligned}
$$

Continuing in this way, with the total loss $\mathcal{L}=\sum_{t=1}^T\mathcal{L}_t$ we get

$$
\frac{\partial\mathcal{L}}{\partial U}=\sum_{t=1}^T\frac{\partial\mathcal{L}_t}{\partial U}
$$

Let $\delta_{t,k}=\frac{\partial\mathcal{L}_t}{\partial z_k}$ be the gradient of the step-$t$ loss with respect to the pre-activation $z_k$. Because $z_k = Uh_{k-1} + Wx_k + b$, the direct derivative $\frac{\partial z_k}{\partial U}$ contributes the factor $h_{k-1}^T$, which gives

$$
\frac{\partial\mathcal{L}_t}{\partial U}=\sum_{k=1}^t \delta_{t,k}h_{k-1}^T
$$

and therefore

$$
\frac{\partial\mathcal{L}}{\partial U}=\sum_{t=1}^T \sum_{k=1}^t \delta_{t,k}h_{k-1}^T
$$

In the same way,

$$
\frac{\partial\mathcal{L}}{\partial W}=\sum_{t=1}^T \sum_{k=1}^t \delta_{t,k} x_k^T, \qquad
\frac{\partial\mathcal{L}}{\partial b}=\sum_{t=1}^T \sum_{k=1}^t \delta_{t,k}
$$
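The derivation leaves $\delta_{t,k}$ implicit. With an elementwise $f$, the chain rule through $z_{k+1}=Uh_k+Wx_{k+1}+b$ gives $\delta_{t,k}=f'(z_k)\odot\left(U^T\delta_{t,k+1}\right)$ for $k<t$, starting from $\delta_{t,t}=\frac{\partial\mathcal{L}_t}{\partial z_t}$. Summing over $t$, the combined error $\delta_k=\sum_{t\ge k}\delta_{t,k}$ satisfies $\delta_k=f'(z_k)\odot\left(\frac{\partial\mathcal{L}_k}{\partial h_k}+U^T\delta_{k+1}\right)$ with $\delta_{T+1}=0$, so the double sum collapses to $\frac{\partial\mathcal{L}}{\partial U}=\sum_{k=1}^T\delta_k h_{k-1}^T$, which one backward sweep can accumulate. The NumPy sketch below checks that both forms agree with central finite differences. It assumes $f=\tanh$, a linear output $\widehat{y_t}=Vh_t$ and a squared-error loss $\mathcal{L}_t=\frac{1}{2}\lVert\widehat{y_t}-y_t\rVert^2$; $V$, $y_t$ and the helper names are illustrative choices, since the exercise leaves $g$ and the loss unspecified.

```python
import numpy as np

# Assumptions for illustration (the exercise leaves g and the loss open):
# f = tanh, y_hat_t = V h_t, L_t = 0.5 * ||y_hat_t - y_t||^2. Column vectors throughout.


def forward(U, W, b, V, h0, xs):
    """Return [h_0, h_1, ..., h_T] and [y_hat_1, ..., y_hat_T]."""
    hs, yhats = [h0], []
    for x in xs:
        h = np.tanh(U @ hs[-1] + W @ x + b)  # h_t = f(U h_{t-1} + W x_t + b)
        hs.append(h)
        yhats.append(V @ h)
    return hs, yhats


def total_loss(U, W, b, V, h0, xs, ys):
    _, yhats = forward(U, W, b, V, h0, xs)
    return sum(0.5 * np.sum((yh - y) ** 2) for yh, y in zip(yhats, ys))


def grad_U_double_sum(U, W, b, V, h0, xs, ys):
    """dL/dU = sum_t sum_{k<=t} delta_{t,k} h_{k-1}^T, one backward chain per t."""
    hs, yhats = forward(U, W, b, V, h0, xs)
    grad = np.zeros_like(U)
    for t in range(1, len(xs) + 1):
        # delta_{t,t} = f'(z_t) * dL_t/dh_t, with tanh'(z_t) = 1 - h_t^2
        delta = (1 - hs[t] ** 2) * (V.T @ (yhats[t - 1] - ys[t - 1]))
        for k in range(t, 0, -1):
            grad += delta @ hs[k - 1].T
            if k > 1:
                # delta_{t,k-1} = f'(z_{k-1}) * (U^T delta_{t,k})
                delta = (1 - hs[k - 1] ** 2) * (U.T @ delta)
    return grad


def grad_U_single_pass(U, W, b, V, h0, xs, ys):
    """Same gradient via delta_k = f'(z_k) * (dL_k/dh_k + U^T delta_{k+1})."""
    hs, yhats = forward(U, W, b, V, h0, xs)
    grad = np.zeros_like(U)
    delta_next = np.zeros_like(h0)  # delta_{T+1} = 0
    for k in range(len(xs), 0, -1):
        dh = V.T @ (yhats[k - 1] - ys[k - 1]) + U.T @ delta_next
        delta = (1 - hs[k] ** 2) * dh
        grad += delta @ hs[k - 1].T
        delta_next = delta
    return grad


def numeric_grad_U(U, W, b, V, h0, xs, ys, eps=1e-6):
    """Central finite differences of the total loss with respect to each entry of U."""
    grad = np.zeros_like(U)
    for i, j in np.ndindex(U.shape):
        Up, Um = U.copy(), U.copy()
        Up[i, j] += eps
        Um[i, j] -= eps
        grad[i, j] = (total_loss(Up, W, b, V, h0, xs, ys)
                      - total_loss(Um, W, b, V, h0, xs, ys)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
n_h, n_x, n_y, T = 4, 3, 2, 3
U = rng.normal(scale=0.5, size=(n_h, n_h))
W = rng.normal(scale=0.5, size=(n_h, n_x))
b = rng.normal(scale=0.1, size=(n_h, 1))
V = rng.normal(scale=0.5, size=(n_y, n_h))
h0 = rng.normal(scale=0.5, size=(n_h, 1))
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
ys = [rng.normal(size=(n_y, 1)) for _ in range(T)]

g_double = grad_U_double_sum(U, W, b, V, h0, xs, ys)
g_single = grad_U_single_pass(U, W, b, V, h0, xs, ys)
g_numeric = numeric_grad_U(U, W, b, V, h0, xs, ys)
print(np.allclose(g_double, g_single), np.allclose(g_double, g_numeric, atol=1e-6))
```

The single-pass form is the one worth implementing: it costs $O(T)$ matrix products instead of $O(T^2)$, and the gradients for $W$ and $b$ follow by replacing $h_{k-1}^T$ with $x_k^T$ or dropping it.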

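Exercise 2

Design a simple recurrent neural network model, implement its backpropagation operator with NumPy and with PyTorch, and test both on numerical inputs.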

The code is as follows:

import torch
import numpy as np
 
 
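class RNNCell:
    """NumPy tanh RNN cell whose backward() implements BPTT.

    Forward: h_t = tanh(x_t W_ih^T + h_{t-1} W_hh^T + b_ih + b_hh)
    The *_stack lists cache per-step inputs and hidden states for backward().
    """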
    def __init__(self, weight_ih, weight_hh,
                 bias_ih, bias_hh):
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
 
        self.x_stack = []
        self.dx_list = []
        self.dw_ih_stack = []
        self.dw_hh_stack = []
        self.db_ih_stack = []
        self.db_hh_stack = []
 
        self.prev_hidden_stack = []
        self.next_hidden_stack = []
 
        # gradient w.r.t. the hidden state, carried back from the next time step
        self.prev_dh = None
 
    def __call__(self, x, prev_hidden):
        self.x_stack.append(x)
 
        next_h = np.tanh(
            np.dot(x, self.weight_ih.T)
            + np.dot(prev_hidden, self.weight_hh.T)
            + self.bias_ih + self.bias_hh)
 
        self.prev_hidden_stack.append(prev_hidden)
        self.next_hidden_stack.append(next_h)
        # reset the carried gradient: when backward() starts at the last
        # step, no gradient arrives from a later step
        self.prev_dh = np.zeros(next_h.shape)
        return next_h
 
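    def backward(self, dh):
        # dh: upstream gradient of the output at the current step.
        # Cached steps are popped in reverse (LIFO), i.e. from t = T down to 1.
        x = self.x_stack.pop()
        prev_hidden = self.prev_hidden_stack.pop()
        next_hidden = self.next_hidden_stack.pop()

        # delta_t = (dL/dh_t from this step + gradient carried from step t+1)
        #           * tanh'(z_t), where tanh'(z_t) = 1 - h_t ** 2
        d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)
        # gradient passed back to h_{t-1} (U^T delta_t in column notation)
        self.prev_dh = np.dot(d_tanh, self.weight_hh)

        # gradient w.r.t. the input x_t, stored in time order
        dx = np.dot(d_tanh, self.weight_ih)
        self.dx_list.insert(0, dx)

        # per-step weight gradients; the caller sums them over time
        dw_ih = np.dot(d_tanh.T, x)
        self.dw_ih_stack.append(dw_ih)

        dw_hh = np.dot(d_tanh.T, prev_hidden)
        self.dw_hh_stack.append(dw_hh)

        # both biases receive delta_t; the caller sums over time and batch
        self.db_ih_stack.append(d_tanh)
        self.db_hh_stack.append(d_tanh)

        return self.dx_list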
 
 
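if __name__ == '__main__':
    np.random.seed(123)
    torch.random.manual_seed(123)
    np.set_printoptions(precision=6, suppress=True)

    # PyTorch reference: single-layer tanh RNN, input size 4, hidden size 5
    rnn_PyTorch = torch.nn.RNN(4, 5).double()
    # all_weights[0] = [weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0];
    # the NumPy cell shares these arrays, so both models use identical parameters
    rnn_numpy = RNNCell(rnn_PyTorch.all_weights[0][0].data.numpy(),
                        rnn_PyTorch.all_weights[0][1].data.numpy(),
                        rnn_PyTorch.all_weights[0][2].data.numpy(),
                        rnn_PyTorch.all_weights[0][3].data.numpy())

    nums = 3  # sequence length
    # inputs: (seq_len, batch, input_size) = (3, 3, 4)
    x3_numpy = np.random.random((nums, 3, 4))
    x3_tensor = torch.tensor(x3_numpy, requires_grad=True)

    # initial hidden state: (num_layers, batch, hidden_size) = (1, 3, 5)
    h3_numpy = np.random.random((1, 3, 5))
    h3_tensor = torch.tensor(h3_numpy, requires_grad=True)

    # upstream gradients for every step's output: (seq_len, batch, hidden_size);
    # they are constants, so they do not need requires_grad
    dh_numpy = np.random.random((nums, 3, 5))
    dh_tensor = torch.tensor(dh_numpy)

    # nn.RNN returns (output, h_n); after this line h3_tensor[0] is the
    # output sequence containing the hidden state of every step
    h3_tensor = rnn_PyTorch(x3_tensor, h3_tensor)
    h_numpy_list = []

    # NumPy forward pass, one time step at a time
    h_numpy = h3_numpy[0]
    for i in range(nums):
        h_numpy = rnn_numpy(x3_numpy[i], h_numpy)
        h_numpy_list.append(h_numpy)

    # backpropagate the same upstream gradients through both models
    h3_tensor[0].backward(dh_tensor)
    for i in reversed(range(nums)):
        rnn_numpy.backward(dh_numpy[i])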
 
    print("numpy_hidden :\n", np.array(h_numpy_list))
    print("torch_hidden :\n", h3_tensor[0].data.numpy())
    print("-----------------------------------------------")
 
    print("dx_numpy :\n", np.array(rnn_numpy.dx_list))
    print("dx_torch :\n", x3_tensor.grad.data.numpy())
    print("------------------------------------------------")
 
    print("dw_ih_numpy :\n",
          np.sum(rnn_numpy.dw_ih_stack, axis=0))
    print("dw_ih_torch :\n",
          rnn_PyTorch.all_weights[0][0].grad.data.numpy())
    print("------------------------------------------------")
 
    print("dw_hh_numpy :\n",
          np.sum(rnn_numpy.dw_hh_stack, axis=0))
    print("dw_hh_torch :\n",
          rnn_PyTorch.all_weights[0][1].grad.data.numpy())
    print("------------------------------------------------")
 
    print("db_ih_numpy :\n",
          np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))
    print("db_ih_torch :\n",
          rnn_PyTorch.all_weights[0][2].grad.data.numpy())
    print("-----------------------------------------------")
    print("db_hh_numpy :\n",
          np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))
    print("db_hh_torch :\n",
          rnn_PyTorch.all_weights[0][3].grad.data.numpy())

Output:

numpy_hidden :
 [[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391]
  [ 0.365172 -0.361254  0.426838 -0.448951  0.331553]
  [ 0.589187 -0.188248  0.684941 -0.45859   0.190099]]

 [[ 0.146213 -0.306517  0.297109  0.370957 -0.040084]
  [-0.009201 -0.365735  0.333659  0.486789  0.061897]
  [ 0.030064 -0.282985  0.42643   0.025871  0.026388]]

 [[ 0.225432 -0.015057  0.116555  0.080901  0.260097]
  [ 0.368327  0.258664  0.357446  0.177961  0.55928 ]
  [ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
torch_hidden :
 [[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391]
  [ 0.365172 -0.361254  0.426838 -0.448951  0.331553]
  [ 0.589187 -0.188248  0.684941 -0.45859   0.190099]]

 [[ 0.146213 -0.306517  0.297109  0.370957 -0.040084]
  [-0.009201 -0.365735  0.333659  0.486789  0.061897]
  [ 0.030064 -0.282985  0.42643   0.025871  0.026388]]

 [[ 0.225432 -0.015057  0.116555  0.080901  0.260097]
  [ 0.368327  0.258664  0.357446  0.177961  0.55928 ]
  [ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
-----------------------------------------------
dx_numpy :
 [[[-0.643965  0.215931 -0.476378  0.072387]
  [-1.221727  0.221325 -0.757251  0.092991]
  [-0.59872  -0.065826 -0.390795  0.037424]]

 [[-0.537631 -0.303022 -0.364839  0.214627]
  [-0.815198  0.392338 -0.564135  0.217464]
  [-0.931365 -0.254144 -0.561227  0.164795]]

 [[-1.055966  0.249554 -0.623127  0.009784]
  [-0.45858   0.108994 -0.240168  0.117779]
  [-0.957469  0.315386 -0.616814  0.205634]]]
dx_torch :
 [[[-0.643965  0.215931 -0.476378  0.072387]
  [-1.221727  0.221325 -0.757251  0.092991]
  [-0.59872  -0.065826 -0.390795  0.037424]]

 [[-0.537631 -0.303022 -0.364839  0.214627]
  [-0.815198  0.392338 -0.564135  0.217464]
  [-0.931365 -0.254144 -0.561227  0.164795]]

 [[-1.055966  0.249554 -0.623127  0.009784]
  [-0.45858   0.108994 -0.240168  0.117779]
  [-0.957469  0.315386 -0.616814  0.205634]]]
------------------------------------------------
dw_ih_numpy :
 [[3.918335 2.958509 3.725173 4.157478]
 [1.261197 0.812825 1.10621  0.97753 ]
 [2.216469 1.718251 2.366936 2.324907]
 [3.85458  3.052212 3.643157 3.845696]
 [1.806807 1.50062  1.615917 1.521762]]
dw_ih_torch :
 [[3.918335 2.958509 3.725173 4.157478]
 [1.261197 0.812825 1.10621  0.97753 ]
 [2.216469 1.718251 2.366936 2.324907]
 [3.85458  3.052212 3.643157 3.845696]
 [1.806807 1.50062  1.615917 1.521762]]
------------------------------------------------
dw_hh_numpy :
 [[ 2.450078  0.243735  4.269672  0.577224  1.46911 ]
 [ 0.421015  0.372353  0.994656  0.962406  0.518992]
 [ 1.079054  0.042843  2.12169   0.863083  0.757618]
 [ 2.225794  0.188735  3.682347  0.934932  0.955984]
 [ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
dw_hh_torch :
 [[ 2.450078  0.243735  4.269672  0.577224  1.46911 ]
 [ 0.421015  0.372353  0.994656  0.962406  0.518992]
 [ 1.079054  0.042843  2.12169   0.863083  0.757618]
 [ 2.225794  0.188735  3.682347  0.934932  0.955984]
 [ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
------------------------------------------------
db_ih_numpy :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_ih_torch :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
-----------------------------------------------
db_hh_numpy :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_hh_torch :
 [7.568411 2.175445 4.335336 6.820628 3.51003 ]
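The printouts above agree to six decimal places, but the match is checked by eye and relies on PyTorch. As a complementary, PyTorch-free check, the NumPy sketch below rewrites the same row-major, batched recursion as `RNNCell.backward` as plain functions and compares it with central finite differences using `np.allclose`. It uses the fact that passing `dh_tensor` as the upstream gradient to PyTorch's `backward` computes the gradient of the scalar $\sum_t \sum dh_t \odot h_t$ (a vector-Jacobian product), so that scalar serves as the loss. The helper names (`rnn_forward`, `rnn_backward`, `numeric_grad`), the seed and the weight range are illustrative assumptions, not part of the original assignment.

```python
import numpy as np

# Row-major, batched tanh RNN with the same formulas as RNNCell:
# h_t = tanh(x_t W_ih^T + h_{t-1} W_hh^T + b_ih + b_hh)


def rnn_forward(xs, h0, w_ih, w_hh, b_ih, b_hh):
    """Return [h_0, h_1, ..., h_T]; each entry has shape (batch, hidden)."""
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(x @ w_ih.T + hs[-1] @ w_hh.T + b_ih + b_hh))
    return hs


def rnn_backward(xs, hs, dhs, w_ih, w_hh):
    """BPTT with the RNNCell.backward recursion; dhs[t] is the upstream gradient of hs[t + 1]."""
    dxs = [None] * len(xs)
    dw_ih, dw_hh = np.zeros_like(w_ih), np.zeros_like(w_hh)
    db = np.zeros(w_ih.shape[0])
    carry = np.zeros_like(hs[0])  # gradient arriving from the later step (prev_dh)
    for t in reversed(range(len(xs))):
        d_tanh = (dhs[t] + carry) * (1 - hs[t + 1] ** 2)
        carry = d_tanh @ w_hh
        dxs[t] = d_tanh @ w_ih
        dw_ih += d_tanh.T @ xs[t]
        dw_hh += d_tanh.T @ hs[t]
        db += d_tanh.sum(axis=0)  # identical for b_ih and b_hh
    return np.array(dxs), dw_ih, dw_hh, db


def loss(xs, h0, dhs, w_ih, w_hh, b_ih, b_hh):
    """Scalar whose gradient equals what backward(dh) propagates: sum_t <dh_t, h_t>."""
    hs = rnn_forward(xs, h0, w_ih, w_hh, b_ih, b_hh)
    return sum(np.sum(dh * h) for dh, h in zip(dhs, hs[1:]))


def numeric_grad(f, a, eps=1e-6):
    """Central differences of the zero-argument scalar function f w.r.t. array a (perturbed in place)."""
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + eps
        f_plus = f()
        a[idx] = old - eps
        f_minus = f()
        a[idx] = old  # restore the original value
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g


rng = np.random.default_rng(123)
seq_len, batch, n_in, n_h = 3, 3, 4, 5  # same shapes as the test above
w_ih = rng.uniform(-0.4, 0.4, (n_h, n_in))
w_hh = rng.uniform(-0.4, 0.4, (n_h, n_h))
b_ih = rng.uniform(-0.4, 0.4, n_h)
b_hh = rng.uniform(-0.4, 0.4, n_h)
xs = rng.random((seq_len, batch, n_in))
h0 = rng.random((batch, n_h))
dhs = rng.random((seq_len, batch, n_h))

hs = rnn_forward(xs, h0, w_ih, w_hh, b_ih, b_hh)
dx, dw_ih, dw_hh, db = rnn_backward(xs, hs, dhs, w_ih, w_hh)


def objective():
    return loss(xs, h0, dhs, w_ih, w_hh, b_ih, b_hh)


checks = {
    "dx": np.allclose(dx, numeric_grad(objective, xs), atol=1e-6),
    "dw_ih": np.allclose(dw_ih, numeric_grad(objective, w_ih), atol=1e-6),
    "dw_hh": np.allclose(dw_hh, numeric_grad(objective, w_hh), atol=1e-6),
    "db": np.allclose(db, numeric_grad(objective, b_ih), atol=1e-6),
}
print(checks)
```

Every entry of `checks` should be `True`. Perturbing only `b_ih` is enough for the bias check, because `b_ih` and `b_hh` enter the pre-activation identically and therefore have the same gradient, just as the two bias rows in the output above coincide.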
