Implementing the LSTM Recurrent Neural Network and Its Backpropagation: Pure Python vs. PyTorch

Abstract

This article implements the LSTM recurrent neural network and its backpropagation in pure Python (NumPy) and verifies the results against PyTorch.

Related

For the companion code and the full gradient derivation, see the article:

长短期记忆网络LSTMCell单元详解及反向传播的梯度求导 (a detailed explanation of the LSTMCell unit and the gradient derivation for its backpropagation)

Article index:
https://blog.csdn.net/oBrightLamp/article/details/85067981

Main Text

1. The LSTMCell Class
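For reference, the forward pass implemented below computes the standard LSTM cell equations, with the four gates stacked along the row dimension of the weights in PyTorch's i, f, g, o order ($\sigma$ denotes the sigmoid function and $\odot$ elementwise multiplication):

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$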

File path: vanilla_nn/lstmcell.py

import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


class LSTMCell:
    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        # Parameters follow the torch.nn.LSTMCell layout: the four gates
        # (input, forget, cell, output) are stacked along dim 0, so
        # weight_ih has shape (4 * hidden_size, input_size) and
        # weight_hh has shape (4 * hidden_size, hidden_size).
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh

        # Gradients carried backwards across time steps during BPTT.
        self.dc_prev = None
        self.dh_prev = None

        # Per-step parameter gradients; summed over time by the caller.
        self.weight_ih_grad_stack = []
        self.weight_hh_grad_stack = []
        self.bias_ih_grad_stack = []
        self.bias_hh_grad_stack = []

        # Forward-pass caches, one entry per time step.
        self.x_stack = []
        self.dx_list = []
        self.dh_prev_stack = []

        self.h_prev_stack = []
        self.c_prev_stack = []

        self.h_next_stack = []
        self.c_next_stack = []

        self.input_gate_stack = []
        self.forget_gate_stack = []
        self.output_gate_stack = []
        self.cell_memory_stack = []

    def __call__(self, x, h_prev, c_prev):
        # Pre-activations of all four gates in a single affine map:
        # a = x @ W_ih.T + h_prev @ W_hh.T + b_ih + b_hh
        a_vector = np.dot(x, self.weight_ih.T) + np.dot(h_prev, self.weight_hh.T)
        a_vector += self.bias_ih + self.bias_hh

        # Split the pre-activations into the i, f, g, o gate blocks.
        h_size = np.shape(h_prev)[1]
        a_i = a_vector[:, h_size * 0:h_size * 1]
        a_f = a_vector[:, h_size * 1:h_size * 2]
        a_c = a_vector[:, h_size * 2:h_size * 3]
        a_o = a_vector[:, h_size * 3:]

        input_gate = sigmoid(a_i)
        forget_gate = sigmoid(a_f)
        cell_memory = np.tanh(a_c)  # candidate cell state g_t
        output_gate = sigmoid(a_o)

        # c_t = f_t * c_{t-1} + i_t * g_t ;  h_t = o_t * tanh(c_t)
        c_next = (forget_gate * c_prev) + (input_gate * cell_memory)
        h_next = output_gate * np.tanh(c_next)

        # Cache everything the backward pass will need.
        self.x_stack.append(x)

        self.h_prev_stack.append(h_prev)
        self.c_prev_stack.append(c_prev)

        self.c_next_stack.append(c_next)
        self.h_next_stack.append(h_next)

        self.input_gate_stack.append(input_gate)
        self.forget_gate_stack.append(forget_gate)
        self.output_gate_stack.append(output_gate)
        self.cell_memory_stack.append(cell_memory)

        # Reset the carried BPTT gradients before a fresh backward pass.
        self.dc_prev = np.zeros_like(c_next)
        self.dh_prev = np.zeros_like(h_next)

        return h_next, c_next

    def backward(self, dh_next):
        # backward must be called once per step, in reverse time order;
        # pop the caches of the latest remaining time step.
        x = self.x_stack.pop()

        h_prev = self.h_prev_stack.pop()
        c_prev = self.c_prev_stack.pop()

        c_next = self.c_next_stack.pop()

        input_gate = self.input_gate_stack.pop()
        forget_gate = self.forget_gate_stack.pop()
        output_gate = self.output_gate_stack.pop()
        cell_memory = self.cell_memory_stack.pop()

        # Total gradient into h_t: the upstream gradient at this step
        # plus the gradient carried back from step t+1.
        dh = dh_next + self.dh_prev

        # dc_t = dh_t * o_t * (1 - tanh^2(c_t)) + carried dc from t+1
        d_tanh_c = dh * output_gate * (1 - np.square(np.tanh(c_next)))
        dc = d_tanh_c + self.dc_prev

        # Gradient flowing to c_{t-1} through the forget gate.
        dc_prev = dc * forget_gate
        self.dc_prev = dc_prev

        d_input_gate = dc * cell_memory
        d_forget_gate = dc * c_prev
        d_cell_memory = dc * input_gate

        d_output_gate = dh * np.tanh(c_next)

        # Backprop through the gate nonlinearities:
        # sigmoid'(a) = s * (1 - s),  tanh'(a) = 1 - tanh^2(a).
        d_ai = d_input_gate * input_gate * (1 - input_gate)
        d_af = d_forget_gate * forget_gate * (1 - forget_gate)
        d_ao = d_output_gate * output_gate * (1 - output_gate)
        d_ac = d_cell_memory * (1 - np.square(cell_memory))

        # Reassemble in the same i, f, g, o order used by the forward pass.
        da = np.concatenate((d_ai, d_af, d_ac, d_ao), axis=1)

        dx = np.dot(da, self.weight_ih)
        dh_prev = np.dot(da, self.weight_hh)
        self.dh_prev = dh_prev

        self.dx_list.insert(0, dx)
        self.dh_prev_stack.append(dh_prev)

        # Per-step parameter gradients; the caller sums the stacks over time.
        self.weight_ih_grad_stack.append(np.dot(da.T, x))
        self.weight_hh_grad_stack.append(np.dot(da.T, h_prev))

        # Both biases enter the pre-activation additively, so they
        # receive identical gradients.
        db = np.sum(da, axis=0)
        self.bias_ih_grad_stack.append(db)
        self.bias_hh_grad_stack.append(db)

        return dh_prev
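The backward method implements the gradients that follow directly from the forward equations above. Writing $\delta$ for the gradient of the loss with respect to the subscripted quantity, one step of backpropagation through time computes:

$$
\begin{aligned}
\delta c_t &= \delta h_t \odot o_t \odot (1 - \tanh^2 c_t) + \delta c_t^{carry},
\qquad \delta c_{t-1} = \delta c_t \odot f_t \\
\delta a_i &= \delta c_t \odot g_t \odot i_t (1 - i_t),
\qquad \delta a_f = \delta c_t \odot c_{t-1} \odot f_t (1 - f_t) \\
\delta a_g &= \delta c_t \odot i_t \odot (1 - g_t^2),
\qquad \delta a_o = \delta h_t \odot \tanh(c_t) \odot o_t (1 - o_t) \\
\delta x_t &= \delta a \, W_{ih},
\qquad \delta h_{t-1} = \delta a \, W_{hh} \\
\delta W_{ih} &= \delta a^{\top} x_t,
\qquad \delta W_{hh} = \delta a^{\top} h_{t-1},
\qquad \delta b_{ih} = \delta b_{hh} = \textstyle\sum_{\text{batch}} \delta a
\end{aligned}
$$

where $\delta h_t$ sums the upstream gradient at step $t$ with the $\delta h$ carried back from step $t+1$, and $\delta a = (\delta a_i, \delta a_f, \delta a_g, \delta a_o)$ is the concatenation along the gate dimension.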

2. LSTMCell Test

import torch
import numpy as np
from vanilla_nn.lstmcell import LSTMCell

np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True)

# Copy the PyTorch cell's randomly initialised parameters into the
# NumPy implementation so that both compute identical results.
lstm_torch = torch.nn.LSTMCell(2, 3).double()
lstm_numpy = LSTMCell(lstm_torch.weight_ih.data.numpy(),
                      lstm_torch.weight_hh.data.numpy(),
                      lstm_torch.bias_ih.data.numpy(),
                      lstm_torch.bias_hh.data.numpy())

x_numpy = np.random.random((4, 2))
x_torch = torch.tensor(x_numpy, requires_grad=True)

h_numpy = np.random.random((4, 3))
h_torch = torch.tensor(h_numpy, requires_grad=True)

c_numpy = np.random.random((4, 3))
c_torch = torch.tensor(c_numpy, requires_grad=True)

dh_numpy = np.random.random((4, 3))
dh_torch = torch.tensor(dh_numpy, requires_grad=True)

# One forward step with both implementations, then backprop the same
# upstream gradient through each.
h_numpy, c_numpy = lstm_numpy(x_numpy, h_numpy, c_numpy)
h_torch, c_torch = lstm_torch(x_torch, (h_torch, c_torch))
h_torch.backward(dh_torch)

dh_numpy = lstm_numpy.backward(dh_numpy)

print("--- 代码输出 ---")
print("h_numpy :\n", h_numpy)
print("h_torch :\n", h_torch.data.numpy())

print("---------")
print("c_numpy :\n", c_numpy)
print("c_torch :\n", c_torch.data.numpy())

print("---------")
print("dx_numpy :\n", np.sum(lstm_numpy.dx_list, axis=0))
print("dx_torch :\n", x_torch.grad.data.numpy())

print("---------")
print("w_ih_grad_numpy :\n",
      np.sum(lstm_numpy.weight_ih_grad_stack, axis=0))
print("w_ih_grad_torch :\n",
      lstm_torch.weight_ih.grad.data.numpy())

print("---------")
print("w_hh_grad_numpy :\n",
      np.sum(lstm_numpy.weight_hh_grad_stack, axis=0))
print("w_hh_grad_torch :\n",
      lstm_torch.weight_hh.grad.data.numpy())

print("---------")
print("b_ih_grad_numpy :\n",
      np.sum(lstm_numpy.bias_ih_grad_stack, axis=0))
print("b_ih_grad_torch :\n",
      lstm_torch.bias_ih.grad.data.numpy())

print("---------")
print("b_hh_grad_numpy :\n",
      np.sum(lstm_numpy.bias_hh_grad_stack, axis=0))
print("b_hh_grad_torch :\n",
      lstm_torch.bias_hh.grad.data.numpy())

"""
--- Output ---
h_numpy :
 [[ 0.055856  0.234159  0.138457]
 [ 0.094461  0.245843  0.224411]
 [ 0.020396  0.086745  0.082545]
 [-0.003794  0.040677  0.063094]]
h_torch :
 [[ 0.055856  0.234159  0.138457]
 [ 0.094461  0.245843  0.224411]
 [ 0.020396  0.086745  0.082545]
 [-0.003794  0.040677  0.063094]]
---------
c_numpy :
 [[ 0.092093  0.384992  0.213364]
 [ 0.151362  0.424671  0.318313]
 [ 0.033245  0.141979  0.120822]
 [-0.0061    0.062946  0.094999]]
c_torch :
 [[ 0.092093  0.384992  0.213364]
 [ 0.151362  0.424671  0.318313]
 [ 0.033245  0.141979  0.120822]
 [-0.0061    0.062946  0.094999]]
---------
dx_numpy :
 [[-0.144016  0.029775]
 [-0.229789  0.140921]
 [-0.246041 -0.009354]
 [-0.088844  0.036652]]
dx_torch :
 [[-0.144016  0.029775]
 [-0.229789  0.140921]
 [-0.246041 -0.009354]
 [-0.088844  0.036652]]
---------
w_ih_grad_numpy :
 [[-0.056788 -0.036448]
 [ 0.018742  0.014428]
 [ 0.007827  0.024828]
 [ 0.07856   0.05437 ]
 [ 0.061267  0.045952]
 [ 0.083886  0.0655  ]
 [ 0.229755  0.156008]
 [ 0.345218  0.251984]
 [ 0.430385  0.376664]
 [ 0.014239  0.011767]
 [ 0.054866  0.044531]
 [ 0.04654   0.048565]]
w_ih_grad_torch :
 [[-0.056788 -0.036448]
 [ 0.018742  0.014428]
 [ 0.007827  0.024828]
 [ 0.07856   0.05437 ]
 [ 0.061267  0.045952]
 [ 0.083886  0.0655  ]
 [ 0.229755  0.156008]
 [ 0.345218  0.251984]
 [ 0.430385  0.376664]
 [ 0.014239  0.011767]
 [ 0.054866  0.044531]
 [ 0.04654   0.048565]]
---------
w_hh_grad_numpy :
 [[-0.037698 -0.048568 -0.021069]
 [ 0.016749  0.016277  0.007556]
 [ 0.035743  0.02156   0.000111]
 [ 0.060824  0.069505  0.029101]
 [ 0.060402  0.051634  0.025643]
 [ 0.068116  0.06966   0.035544]
 [ 0.168965  0.217076  0.075904]
 [ 0.248277  0.290927  0.138279]
 [ 0.384974  0.401949  0.167006]
 [ 0.015448  0.0139    0.005158]
 [ 0.057147  0.048975  0.022261]
 [ 0.057297  0.048308  0.017745]]
w_hh_grad_torch :
 [[-0.037698 -0.048568 -0.021069]
 [ 0.016749  0.016277  0.007556]
 [ 0.035743  0.02156   0.000111]
 [ 0.060824  0.069505  0.029101]
 [ 0.060402  0.051634  0.025643]
 [ 0.068116  0.06966   0.035544]
 [ 0.168965  0.217076  0.075904]
 [ 0.248277  0.290927  0.138279]
 [ 0.384974  0.401949  0.167006]
 [ 0.015448  0.0139    0.005158]
 [ 0.057147  0.048975  0.022261]
 [ 0.057297  0.048308  0.017745]]
---------
b_ih_grad_numpy :
 [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
  0.539519  0.761838  0.027649  0.103695  0.099405]
b_ih_grad_torch :
 [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
  0.539519  0.761838  0.027649  0.103695  0.099405]
---------
b_hh_grad_numpy :
 [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
  0.539519  0.761838  0.027649  0.103695  0.099405]
b_hh_grad_torch :
 [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
  0.539519  0.761838  0.027649  0.103695  0.099405]
"""

3. LSTM Test

import torch
import numpy as np
from vanilla_nn.lstmcell import LSTMCell

np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True)

# A single-layer LSTM; all_weights[0] holds (weight_ih, weight_hh,
# bias_ih, bias_hh) for layer 0, copied into the NumPy cell.
lstm_torch = torch.nn.LSTM(2, 3, 1).double()
lstm_numpy = LSTMCell(lstm_torch.all_weights[0][0].data.numpy(),
                      lstm_torch.all_weights[0][1].data.numpy(),
                      lstm_torch.all_weights[0][2].data.numpy(),
                      lstm_torch.all_weights[0][3].data.numpy())

# Input sequence: (seq_len, batch, input_size) = (3, 4, 2).
x_numpy = np.random.random((3, 4, 2))
x_torch = torch.tensor(x_numpy, requires_grad=True)

# Initial states: (num_layers, batch, hidden_size) = (1, 4, 3).
h_numpy = np.random.random((1, 4, 3))
h_torch = torch.tensor(h_numpy, requires_grad=True)

c_numpy = np.random.random((1, 4, 3))
c_torch = torch.tensor(c_numpy, requires_grad=True)

# Upstream gradient for every step of the output sequence.
dh_numpy = np.random.random((3, 4, 3))
dh_torch = torch.tensor(dh_numpy, requires_grad=True)

out_torch, (h_torch, c_torch) = lstm_torch(x_torch, (h_torch, c_torch))
out_torch.backward(dh_torch)

# Unroll the NumPy cell over the 3 time steps of the sequence.
h0_numpy, c0_numpy = h_numpy[0], c_numpy[0]
for i in range(3):
    h0_numpy, c0_numpy = lstm_numpy(x_numpy[i], h0_numpy, c0_numpy)

# Backpropagate through time, visiting the steps in reverse order.
for i in reversed(range(3)):
    lstm_numpy.backward(dh_numpy[i])

print("--- 代码输出 ---")
print("out_numpy :\n", np.array(lstm_numpy.h_next_stack))
print("out_torch :\n", out_torch.data.numpy())

print("---------")
print("c_numpy :\n", c0_numpy)
print("c_torch :\n", c_torch.data.numpy())

print("---------")
print("dx_numpy :\n", np.array(lstm_numpy.dx_list))
print("dx_torch :\n", x_torch.grad.data.numpy())

print("---------")
print("w_ih_grad_numpy :\n",
      np.sum(lstm_numpy.weight_ih_grad_stack, axis=0))
print("w_ih_grad_torch :\n",
      lstm_torch.all_weights[0][0].grad.data.numpy())

print("---------")
print("w_hh_grad_numpy :\n",
      np.sum(lstm_numpy.weight_hh_grad_stack, axis=0))
print("w_hh_grad_torch :\n",
      lstm_torch.all_weights[0][1].grad.data.numpy())

print("---------")
print("b_ih_grad_numpy :\n",
      np.sum(lstm_numpy.bias_ih_grad_stack, axis=0))
print("b_ih_grad_torch :\n",
      lstm_torch.all_weights[0][2].grad.data.numpy())

print("---------")
print("b_hh_grad_numpy :\n",
      np.sum(lstm_numpy.bias_hh_grad_stack, axis=0))
print("b_hh_grad_torch :\n",
      lstm_torch.all_weights[0][3].grad.data.numpy())

"""
--- Output ---
out_numpy :
 [[[ 0.006445  0.235999  0.221529]
  [ 0.034874  0.218051  0.077521]
  [-0.030951  0.145011  0.123065]
  [-0.077813  0.132753  0.223936]]

 [[-0.090284  0.172555  0.01114 ]
  [-0.098488  0.1853    0.012856]
  [-0.082339  0.141583 -0.048317]
  [-0.142443  0.124778  0.034696]]

 [[-0.086721  0.170245 -0.067097]
  [-0.138738  0.147207 -0.077025]
  [-0.161773  0.112377 -0.097439]
  [-0.166661  0.101749 -0.087182]]]
out_torch :
 [[[ 0.006445  0.235999  0.221529]
  [ 0.034874  0.218051  0.077521]
  [-0.030951  0.145011  0.123065]
  [-0.077813  0.132753  0.223936]]

 [[-0.090284  0.172555  0.01114 ]
  [-0.098488  0.1853    0.012856]
  [-0.082339  0.141583 -0.048317]
  [-0.142443  0.124778  0.034696]]

 [[-0.086721  0.170245 -0.067097]
  [-0.138738  0.147207 -0.077025]
  [-0.161773  0.112377 -0.097439]
  [-0.166661  0.101749 -0.087182]]]
---------
c_numpy :
 [[-0.15123   0.267693 -0.113416]
 [-0.242806  0.216831 -0.118585]
 [-0.281794  0.158563 -0.139627]
 [-0.296978  0.143559 -0.132519]]
c_torch :
 [[[-0.15123   0.267693 -0.113416]
  [-0.242806  0.216831 -0.118585]
  [-0.281794  0.158563 -0.139627]
  [-0.296978  0.143559 -0.132519]]]
---------
dx_numpy :
 [[[-0.233995 -0.015056]
  [-0.314481 -0.021608]
  [-0.176089  0.00849 ]
  [-0.314323  0.070071]]

 [[-0.207276 -0.014927]
  [-0.176493 -0.061383]
  [-0.139164 -0.051902]
  [-0.209488  0.068137]]

 [[-0.104086 -0.052976]
  [-0.087995 -0.03669 ]
  [-0.094893 -0.003202]
  [-0.091384 -0.030927]]]
dx_torch :
 [[[-0.233995 -0.015056]
  [-0.314481 -0.021608]
  [-0.176089  0.00849 ]
  [-0.314323  0.070071]]

 [[-0.207276 -0.014927]
  [-0.176493 -0.061383]
  [-0.139164 -0.051902]
  [-0.209488  0.068137]]

 [[-0.104086 -0.052976]
  [-0.087995 -0.03669 ]
  [-0.094893 -0.003202]
  [-0.091384 -0.030927]]]
---------
w_ih_grad_numpy :
 [[-0.309665 -0.286389]
 [ 0.119474  0.1197  ]
 [-0.060895 -0.053575]
 [ 0.081656  0.069246]
 [ 0.343434  0.306735]
 [ 0.238902  0.180056]
 [ 0.799281  0.762242]
 [ 1.66807   1.574547]
 [ 0.878086  0.861914]
 [-0.134765 -0.128521]
 [ 0.202337  0.194739]
 [ 0.043419  0.034041]]
w_ih_grad_torch :
 [[-0.309665 -0.286389]
 [ 0.119474  0.1197  ]
 [-0.060895 -0.053575]
 [ 0.081656  0.069246]
 [ 0.343434  0.306735]
 [ 0.238902  0.180056]
 [ 0.799281  0.762242]
 [ 1.66807   1.574547]
 [ 0.878086  0.861914]
 [-0.134765 -0.128521]
 [ 0.202337  0.194739]
 [ 0.043419  0.034041]]
---------
w_hh_grad_numpy :
 [[-0.080813 -0.141149 -0.13131 ]
 [ 0.031354  0.062847  0.063331]
 [ 0.011833 -0.017525 -0.007764]
 [ 0.082655  0.059767  0.089204]
 [ 0.154359  0.176164  0.180268]
 [ 0.096947  0.111974  0.105654]
 [ 0.185689  0.399495  0.394687]
 [ 0.454164  0.778306  0.671739]
 [ 0.22264   0.398373  0.417446]
 [ 0.006327 -0.046814 -0.022775]
 [ 0.070304  0.103367  0.09488 ]
 [ 0.029992  0.027785  0.034   ]]
w_hh_grad_torch :
 [[-0.080813 -0.141149 -0.13131 ]
 [ 0.031354  0.062847  0.063331]
 [ 0.011833 -0.017525 -0.007764]
 [ 0.082655  0.059767  0.089204]
 [ 0.154359  0.176164  0.180268]
 [ 0.096947  0.111974  0.105654]
 [ 0.185689  0.399495  0.394687]
 [ 0.454164  0.778306  0.671739]
 [ 0.22264   0.398373  0.417446]
 [ 0.006327 -0.046814 -0.022775]
 [ 0.070304  0.103367  0.09488 ]
 [ 0.029992  0.027785  0.034   ]]
---------
b_ih_grad_numpy :
 [-0.545791  0.260244 -0.103962  0.141001  0.622189  0.328654  1.623599
  3.035249  1.506    -0.245668  0.394114  0.060554]
b_ih_grad_torch :
 [-0.545791  0.260244 -0.103962  0.141001  0.622189  0.328654  1.623599
  3.035249  1.506    -0.245668  0.394114  0.060554]
---------
b_hh_grad_numpy :
 [-0.545791  0.260244 -0.103962  0.141001  0.622189  0.328654  1.623599
  3.035249  1.506    -0.245668  0.394114  0.060554]
b_hh_grad_torch :
 [-0.545791  0.260244 -0.103962  0.141001  0.622189  0.328654  1.623599
  3.035249  1.506    -0.245668  0.394114  0.060554]
"""
