This article implements the LSTM recurrent neural network and its backpropagation in pure Python, and verifies the results against PyTorch.
This is the companion code for the article:
长短期记忆网络LSTMCell单元详解及反向传播的梯度求导 (a detailed explanation of the LSTMCell unit and the gradient derivation for its backward pass)
Article index:
https://blog.csdn.net/oBrightLamp/article/details/85067981
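As a quick reference for the code below, the LSTMCell forward pass, in PyTorch's gate order (input gate $i$, forget gate $f$, cell candidate $g$, output gate $o$), is:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$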
File: vanilla_nn/lstmcell.py
import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))
class LSTMCell:
    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        # Parameters in PyTorch layout: weight_ih is (4 * hidden, input),
        # weight_hh is (4 * hidden, hidden), biases are (4 * hidden,).
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh

        # Recurrent gradients carried back from the following time step.
        self.dc_prev = None
        self.dh_prev = None

        # Per-step parameter gradients; summed over steps for comparison.
        self.weight_ih_grad_stack = []
        self.weight_hh_grad_stack = []
        self.bias_ih_grad_stack = []
        self.bias_hh_grad_stack = []

        # Forward-pass caches consumed by backward in LIFO order.
        self.x_stack = []
        self.dx_list = []
        self.dh_prev_stack = []
        self.h_prev_stack = []
        self.c_prev_stack = []
        self.h_next_stack = []
        self.c_next_stack = []
        self.input_gate_stack = []
        self.forget_gate_stack = []
        self.output_gate_stack = []
        self.cell_memory_stack = []
    def __call__(self, x, h_prev, c_prev):
        # Fused affine transform for all four gates, gate order i, f, g, o.
        a_vector = np.dot(x, self.weight_ih.T) + np.dot(h_prev, self.weight_hh.T)
        a_vector += self.bias_ih + self.bias_hh

        h_size = np.shape(h_prev)[1]
        a_i = a_vector[:, h_size * 0:h_size * 1]
        a_f = a_vector[:, h_size * 1:h_size * 2]
        a_c = a_vector[:, h_size * 2:h_size * 3]
        a_o = a_vector[:, h_size * 3:]

        input_gate = sigmoid(a_i)
        forget_gate = sigmoid(a_f)
        cell_memory = np.tanh(a_c)
        output_gate = sigmoid(a_o)

        c_next = (forget_gate * c_prev) + (input_gate * cell_memory)
        h_next = output_gate * np.tanh(c_next)

        # Cache everything the backward pass will need.
        self.x_stack.append(x)
        self.h_prev_stack.append(h_prev)
        self.c_prev_stack.append(c_prev)
        self.c_next_stack.append(c_next)
        self.h_next_stack.append(h_next)
        self.input_gate_stack.append(input_gate)
        self.forget_gate_stack.append(forget_gate)
        self.output_gate_stack.append(output_gate)
        self.cell_memory_stack.append(cell_memory)

        # Reset the recurrent gradient accumulators; they are zero when
        # backward is first called for the last time step.
        self.dc_prev = np.zeros_like(c_next)
        self.dh_prev = np.zeros_like(h_next)

        return h_next, c_next
    def backward(self, dh_next):
        # Pop the caches in LIFO order, matching backward-through-time.
        x = self.x_stack.pop()
        h_prev = self.h_prev_stack.pop()
        c_prev = self.c_prev_stack.pop()
        c_next = self.c_next_stack.pop()
        input_gate = self.input_gate_stack.pop()
        forget_gate = self.forget_gate_stack.pop()
        output_gate = self.output_gate_stack.pop()
        cell_memory = self.cell_memory_stack.pop()

        # Total gradient at h_next: the external upstream gradient plus
        # the recurrent gradient carried back from the following step.
        dh = dh_next + self.dh_prev

        # h_next = output_gate * tanh(c_next)
        d_tanh_c = dh * output_gate * (1 - np.square(np.tanh(c_next)))
        dc = d_tanh_c + self.dc_prev

        # c_next = forget_gate * c_prev + input_gate * cell_memory
        dc_prev = dc * forget_gate
        self.dc_prev = dc_prev

        d_input_gate = dc * cell_memory
        d_forget_gate = dc * c_prev
        d_cell_memory = dc * input_gate
        d_output_gate = dh * np.tanh(c_next)

        # Backprop through the sigmoid/tanh nonlinearities to the
        # pre-activations of each gate.
        d_ai = d_input_gate * input_gate * (1 - input_gate)
        d_af = d_forget_gate * forget_gate * (1 - forget_gate)
        d_ao = d_output_gate * output_gate * (1 - output_gate)
        d_ac = d_cell_memory * (1 - np.square(cell_memory))

        da = np.concatenate((d_ai, d_af, d_ac, d_ao), axis=1)

        dx = np.dot(da, self.weight_ih)
        dh_prev = np.dot(da, self.weight_hh)
        self.dh_prev = dh_prev

        self.dx_list.insert(0, dx)
        self.dh_prev_stack.append(dh_prev)

        self.weight_ih_grad_stack.append(np.dot(da.T, x))
        self.weight_hh_grad_stack.append(np.dot(da.T, h_prev))

        # bias_ih and bias_hh enter the forward pass as a sum, so they
        # receive identical gradients.
        db = np.sum(da, axis=0)
        self.bias_ih_grad_stack.append(db)
        self.bias_hh_grad_stack.append(db)

        return dh_prev
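The backward method above is a direct transcription of the chain rule. With $\odot$ the elementwise product and $\delta a = [\delta a_i, \delta a_f, \delta a_g, \delta a_o]$ the concatenated pre-activation gradients, the computed quantities are:

$$
\begin{aligned}
\delta h_t &= \delta h_t^{\text{out}} + \delta h_t^{\text{rec}} \\
\delta c_t &= \delta h_t \odot o_t \odot (1 - \tanh^2 c_t) + \delta c_t^{\text{rec}} \\
\delta c_{t-1} &= \delta c_t \odot f_t \\
\delta a_i &= \delta c_t \odot g_t \odot i_t (1 - i_t) \\
\delta a_f &= \delta c_t \odot c_{t-1} \odot f_t (1 - f_t) \\
\delta a_g &= \delta c_t \odot i_t \odot (1 - g_t^2) \\
\delta a_o &= \delta h_t \odot \tanh(c_t) \odot o_t (1 - o_t) \\
\delta x_t &= \delta a \, W_{ih}, \quad \delta h_{t-1} = \delta a \, W_{hh}, \quad
\delta W_{ih} = \delta a^{\top} x_t, \quad \delta W_{hh} = \delta a^{\top} h_{t-1}
\end{aligned}
$$

where $\delta h_t^{\text{rec}}$ and $\delta c_t^{\text{rec}}$ are the recurrent gradients from step $t+1$ (zero at the last step). The first test below compares a single step against torch.nn.LSTMCell.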
import torch
import numpy as np
from vanilla_nn.lstmcell import LSTMCell
np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True)
lstm_torch = torch.nn.LSTMCell(2, 3).double()
lstm_numpy = LSTMCell(lstm_torch.weight_ih.data.numpy(),
                      lstm_torch.weight_hh.data.numpy(),
                      lstm_torch.bias_ih.data.numpy(),
                      lstm_torch.bias_hh.data.numpy())
x_numpy = np.random.random((4, 2))
x_torch = torch.tensor(x_numpy, requires_grad=True)
h_numpy = np.random.random((4, 3))
h_torch = torch.tensor(h_numpy, requires_grad=True)
c_numpy = np.random.random((4, 3))
c_torch = torch.tensor(c_numpy, requires_grad=True)
dh_numpy = np.random.random((4, 3))
dh_torch = torch.tensor(dh_numpy, requires_grad=True)
# One forward step and one backward step in both implementations.
h_numpy, c_numpy = lstm_numpy(x_numpy, h_numpy, c_numpy)
h_torch, c_torch = lstm_torch(x_torch, (h_torch, c_torch))
h_torch.backward(dh_torch)
dh_numpy = lstm_numpy.backward(dh_numpy)
print("--- 代码输出 ---")
print("h_numpy :\n", h_numpy)
print("h_torch :\n", h_torch.data.numpy())
print("---------")
print("c_numpy :\n", c_numpy)
print("c_torch :\n", c_torch.data.numpy())
print("---------")
print("dx_numpy :\n", np.sum(lstm_numpy.dx_list, axis=0))
print("dx_torch :\n", x_torch.grad.data.numpy())
print("---------")
print("w_ih_grad_numpy :\n",
np.sum(lstm_numpy.weight_ih_grad_stack, axis=0))
print("w_ih_grad_torch :\n",
lstm_torch.weight_ih.grad.data.numpy())
print("---------")
print("w_hh_grad_numpy :\n",
np.sum(lstm_numpy.weight_hh_grad_stack, axis=0))
print("w_hh_grad_torch :\n",
lstm_torch.weight_hh.grad.data.numpy())
print("---------")
print("b_ih_grad_numpy :\n",
np.sum(lstm_numpy.bias_ih_grad_stack, axis=0))
print("b_ih_grad_torch :\n",
lstm_torch.bias_ih.grad.data.numpy())
print("---------")
print("b_hh_grad_numpy :\n",
np.sum(lstm_numpy.bias_hh_grad_stack, axis=0))
print("b_hh_grad_torch :\n",
lstm_torch.bias_hh.grad.data.numpy())
"""
--- Code output ---
h_numpy :
[[ 0.055856 0.234159 0.138457]
[ 0.094461 0.245843 0.224411]
[ 0.020396 0.086745 0.082545]
[-0.003794 0.040677 0.063094]]
h_torch :
[[ 0.055856 0.234159 0.138457]
[ 0.094461 0.245843 0.224411]
[ 0.020396 0.086745 0.082545]
[-0.003794 0.040677 0.063094]]
---------
c_numpy :
[[ 0.092093 0.384992 0.213364]
[ 0.151362 0.424671 0.318313]
[ 0.033245 0.141979 0.120822]
[-0.0061 0.062946 0.094999]]
c_torch :
[[ 0.092093 0.384992 0.213364]
[ 0.151362 0.424671 0.318313]
[ 0.033245 0.141979 0.120822]
[-0.0061 0.062946 0.094999]]
---------
dx_numpy :
[[-0.144016 0.029775]
[-0.229789 0.140921]
[-0.246041 -0.009354]
[-0.088844 0.036652]]
dx_torch :
[[-0.144016 0.029775]
[-0.229789 0.140921]
[-0.246041 -0.009354]
[-0.088844 0.036652]]
---------
w_ih_grad_numpy :
[[-0.056788 -0.036448]
[ 0.018742 0.014428]
[ 0.007827 0.024828]
[ 0.07856 0.05437 ]
[ 0.061267 0.045952]
[ 0.083886 0.0655 ]
[ 0.229755 0.156008]
[ 0.345218 0.251984]
[ 0.430385 0.376664]
[ 0.014239 0.011767]
[ 0.054866 0.044531]
[ 0.04654 0.048565]]
w_ih_grad_torch :
[[-0.056788 -0.036448]
[ 0.018742 0.014428]
[ 0.007827 0.024828]
[ 0.07856 0.05437 ]
[ 0.061267 0.045952]
[ 0.083886 0.0655 ]
[ 0.229755 0.156008]
[ 0.345218 0.251984]
[ 0.430385 0.376664]
[ 0.014239 0.011767]
[ 0.054866 0.044531]
[ 0.04654 0.048565]]
---------
w_hh_grad_numpy :
[[-0.037698 -0.048568 -0.021069]
[ 0.016749 0.016277 0.007556]
[ 0.035743 0.02156 0.000111]
[ 0.060824 0.069505 0.029101]
[ 0.060402 0.051634 0.025643]
[ 0.068116 0.06966 0.035544]
[ 0.168965 0.217076 0.075904]
[ 0.248277 0.290927 0.138279]
[ 0.384974 0.401949 0.167006]
[ 0.015448 0.0139 0.005158]
[ 0.057147 0.048975 0.022261]
[ 0.057297 0.048308 0.017745]]
w_hh_grad_torch :
[[-0.037698 -0.048568 -0.021069]
[ 0.016749 0.016277 0.007556]
[ 0.035743 0.02156 0.000111]
[ 0.060824 0.069505 0.029101]
[ 0.060402 0.051634 0.025643]
[ 0.068116 0.06966 0.035544]
[ 0.168965 0.217076 0.075904]
[ 0.248277 0.290927 0.138279]
[ 0.384974 0.401949 0.167006]
[ 0.015448 0.0139 0.005158]
[ 0.057147 0.048975 0.022261]
[ 0.057297 0.048308 0.017745]]
---------
b_ih_grad_numpy :
[-0.084682 0.032588 0.046412 0.126449 0.111421 0.139337 0.361956
0.539519 0.761838 0.027649 0.103695 0.099405]
b_ih_grad_torch :
[-0.084682 0.032588 0.046412 0.126449 0.111421 0.139337 0.361956
0.539519 0.761838 0.027649 0.103695 0.099405]
---------
b_hh_grad_numpy :
[-0.084682 0.032588 0.046412 0.126449 0.111421 0.139337 0.361956
0.539519 0.761838 0.027649 0.103695 0.099405]
b_hh_grad_torch :
[-0.084682 0.032588 0.046412 0.126449 0.111421 0.139337 0.361956
0.539519 0.761838 0.027649 0.103695 0.099405]
"""
import torch
import numpy as np
from vanilla_nn.lstmcell import LSTMCell
np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True)
# Single-layer torch.nn.LSTM over a length-3 sequence; all_weights[0]
# holds [weight_ih, weight_hh, bias_ih, bias_hh] for layer 0.
lstm_torch = torch.nn.LSTM(2, 3, 1).double()
lstm_numpy = LSTMCell(lstm_torch.all_weights[0][0].data.numpy(),
                      lstm_torch.all_weights[0][1].data.numpy(),
                      lstm_torch.all_weights[0][2].data.numpy(),
                      lstm_torch.all_weights[0][3].data.numpy())
x_numpy = np.random.random((3, 4, 2))
x_torch = torch.tensor(x_numpy, requires_grad=True)
h_numpy = np.random.random((1, 4, 3))
h_torch = torch.tensor(h_numpy, requires_grad=True)
c_numpy = np.random.random((1, 4, 3))
c_torch = torch.tensor(c_numpy, requires_grad=True)
dh_numpy = np.random.random((3, 4, 3))
dh_torch = torch.tensor(dh_numpy, requires_grad=True)
out_torch, (h_torch, c_torch) = lstm_torch(x_torch, (h_torch, c_torch))
out_torch.backward(dh_torch)

# Run the same sequence through the NumPy cell step by step, then
# backpropagate through time in reverse order. PyTorch accumulates
# parameter gradients across steps, so the per-step gradient stacks
# are summed below before comparison.
h0_numpy, c0_numpy = h_numpy[0], c_numpy[0]
for i in range(3):
    h0_numpy, c0_numpy = lstm_numpy(x_numpy[i], h0_numpy, c0_numpy)

for i in reversed(range(3)):
    lstm_numpy.backward(dh_numpy[i])
print("--- 代码输出 ---")
print("out_numpy :\n", np.array(lstm_numpy.h_next_stack))
print("out_torch :\n", out_torch.data.numpy())
print("---------")
print("c_numpy :\n", c0_numpy)
print("c_torch :\n", c_torch.data.numpy())
print("---------")
print("dx_numpy :\n", np.array(lstm_numpy.dx_list))
print("dx_torch :\n", x_torch.grad.data.numpy())
print("---------")
print("w_ih_grad_numpy :\n",
np.sum(lstm_numpy.weight_ih_grad_stack, axis=0))
print("w_ih_grad_torch :\n",
lstm_torch.all_weights[0][0].grad.data.numpy())
print("---------")
print("w_hh_grad_numpy :\n",
np.sum(lstm_numpy.weight_hh_grad_stack, axis=0))
print("w_hh_grad_torch :\n",
lstm_torch.all_weights[0][1].grad.data.numpy())
print("---------")
print("b_ih_grad_numpy :\n",
np.sum(lstm_numpy.bias_ih_grad_stack, axis=0))
print("b_ih_grad_torch :\n",
lstm_torch.all_weights[0][2].grad.data.numpy())
print("---------")
print("b_hh_grad_numpy :\n",
np.sum(lstm_numpy.bias_hh_grad_stack, axis=0))
print("b_hh_grad_torch :\n",
lstm_torch.all_weights[0][3].grad.data.numpy())
"""
--- Code output ---
out_numpy :
[[[ 0.006445 0.235999 0.221529]
[ 0.034874 0.218051 0.077521]
[-0.030951 0.145011 0.123065]
[-0.077813 0.132753 0.223936]]
[[-0.090284 0.172555 0.01114 ]
[-0.098488 0.1853 0.012856]
[-0.082339 0.141583 -0.048317]
[-0.142443 0.124778 0.034696]]
[[-0.086721 0.170245 -0.067097]
[-0.138738 0.147207 -0.077025]
[-0.161773 0.112377 -0.097439]
[-0.166661 0.101749 -0.087182]]]
out_torch :
[[[ 0.006445 0.235999 0.221529]
[ 0.034874 0.218051 0.077521]
[-0.030951 0.145011 0.123065]
[-0.077813 0.132753 0.223936]]
[[-0.090284 0.172555 0.01114 ]
[-0.098488 0.1853 0.012856]
[-0.082339 0.141583 -0.048317]
[-0.142443 0.124778 0.034696]]
[[-0.086721 0.170245 -0.067097]
[-0.138738 0.147207 -0.077025]
[-0.161773 0.112377 -0.097439]
[-0.166661 0.101749 -0.087182]]]
---------
c_numpy :
[[-0.15123 0.267693 -0.113416]
[-0.242806 0.216831 -0.118585]
[-0.281794 0.158563 -0.139627]
[-0.296978 0.143559 -0.132519]]
c_torch :
[[[-0.15123 0.267693 -0.113416]
[-0.242806 0.216831 -0.118585]
[-0.281794 0.158563 -0.139627]
[-0.296978 0.143559 -0.132519]]]
---------
dx_numpy :
[[[-0.233995 -0.015056]
[-0.314481 -0.021608]
[-0.176089 0.00849 ]
[-0.314323 0.070071]]
[[-0.207276 -0.014927]
[-0.176493 -0.061383]
[-0.139164 -0.051902]
[-0.209488 0.068137]]
[[-0.104086 -0.052976]
[-0.087995 -0.03669 ]
[-0.094893 -0.003202]
[-0.091384 -0.030927]]]
dx_torch :
[[[-0.233995 -0.015056]
[-0.314481 -0.021608]
[-0.176089 0.00849 ]
[-0.314323 0.070071]]
[[-0.207276 -0.014927]
[-0.176493 -0.061383]
[-0.139164 -0.051902]
[-0.209488 0.068137]]
[[-0.104086 -0.052976]
[-0.087995 -0.03669 ]
[-0.094893 -0.003202]
[-0.091384 -0.030927]]]
---------
w_ih_grad_numpy :
[[-0.309665 -0.286389]
[ 0.119474 0.1197 ]
[-0.060895 -0.053575]
[ 0.081656 0.069246]
[ 0.343434 0.306735]
[ 0.238902 0.180056]
[ 0.799281 0.762242]
[ 1.66807 1.574547]
[ 0.878086 0.861914]
[-0.134765 -0.128521]
[ 0.202337 0.194739]
[ 0.043419 0.034041]]
w_ih_grad_torch :
[[-0.309665 -0.286389]
[ 0.119474 0.1197 ]
[-0.060895 -0.053575]
[ 0.081656 0.069246]
[ 0.343434 0.306735]
[ 0.238902 0.180056]
[ 0.799281 0.762242]
[ 1.66807 1.574547]
[ 0.878086 0.861914]
[-0.134765 -0.128521]
[ 0.202337 0.194739]
[ 0.043419 0.034041]]
---------
w_hh_grad_numpy :
[[-0.080813 -0.141149 -0.13131 ]
[ 0.031354 0.062847 0.063331]
[ 0.011833 -0.017525 -0.007764]
[ 0.082655 0.059767 0.089204]
[ 0.154359 0.176164 0.180268]
[ 0.096947 0.111974 0.105654]
[ 0.185689 0.399495 0.394687]
[ 0.454164 0.778306 0.671739]
[ 0.22264 0.398373 0.417446]
[ 0.006327 -0.046814 -0.022775]
[ 0.070304 0.103367 0.09488 ]
[ 0.029992 0.027785 0.034 ]]
w_hh_grad_torch :
[[-0.080813 -0.141149 -0.13131 ]
[ 0.031354 0.062847 0.063331]
[ 0.011833 -0.017525 -0.007764]
[ 0.082655 0.059767 0.089204]
[ 0.154359 0.176164 0.180268]
[ 0.096947 0.111974 0.105654]
[ 0.185689 0.399495 0.394687]
[ 0.454164 0.778306 0.671739]
[ 0.22264 0.398373 0.417446]
[ 0.006327 -0.046814 -0.022775]
[ 0.070304 0.103367 0.09488 ]
[ 0.029992 0.027785 0.034 ]]
---------
b_ih_grad_numpy :
[-0.545791 0.260244 -0.103962 0.141001 0.622189 0.328654 1.623599
3.035249 1.506 -0.245668 0.394114 0.060554]
b_ih_grad_torch :
[-0.545791 0.260244 -0.103962 0.141001 0.622189 0.328654 1.623599
3.035249 1.506 -0.245668 0.394114 0.060554]
---------
b_hh_grad_numpy :
[-0.545791 0.260244 -0.103962 0.141001 0.622189 0.328654 1.623599
3.035249 1.506 -0.245668 0.394114 0.060554]
b_hh_grad_torch :
[-0.545791 0.260244 -0.103962 0.141001 0.622189 0.328654 1.623599
3.035249 1.506 -0.245668 0.394114 0.060554]
"""